diff --git a/examples/factoids.py b/examples/factoids.py index 336dc2b..9118c75 100644 --- a/examples/factoids.py +++ b/examples/factoids.py @@ -1,32 +1,34 @@ import asyncio, re -from argparse import ArgumentParser -from typing import Dict, List, Optional +from argparse import ArgumentParser +from typing import Dict, List, Optional from irctokens import build, Line -from ircrobots import Bot as BaseBot +from ircrobots import Bot as BaseBot from ircrobots import Server as BaseServer from ircrobots import ConnectionParams TRIGGER = "!" + def _delims(s: str, delim: str): s_copy = list(s) while s_copy: char = s_copy.pop(0) if char == delim: if not s_copy: - yield len(s)-(len(s_copy)+1) + yield len(s) - (len(s_copy) + 1) elif not s_copy.pop(0) == delim: - yield len(s)-(len(s_copy)+2) + yield len(s) - (len(s_copy) + 2) + def _sed(sed: str, s: str) -> Optional[str]: if len(sed) > 1: - delim = sed[1] - last = 0 + delim = sed[1] + last = 0 parts: List[str] = [] for i in _delims(sed, delim): parts.append(sed[last:i]) - last = i+1 + last = i + 1 if len(parts) == 4: break if last < (len(sed)): @@ -36,10 +38,10 @@ def _sed(sed: str, s: str) -> Optional[str]: flags_s = (args or [""])[0] flags = re.I if "i" in flags_s else 0 - count = 0 if "g" in flags_s else 1 + count = 0 if "g" in flags_s else 1 for i in reversed(list(_delims(replace, "&"))): - replace = replace[:i] + "\\g<0>" + replace[i+1:] + replace = replace[:i] + "\\g<0>" + replace[i + 1 :] try: compiled = re.compile(pattern, flags) @@ -49,18 +51,22 @@ def _sed(sed: str, s: str) -> Optional[str]: else: return None + class Database: def __init__(self): self._settings: Dict[str, str] = {} async def get(self, context: str, setting: str) -> Optional[str]: return self._settings.get(setting, None) + async def set(self, context: str, setting: str, value: str): self._settings[setting] = value + async def rem(self, context: str, setting: str): if setting in self._settings: del self._settings[setting] + class Server(BaseServer): def __init__(self, bot: Bot, name: str, channel: str, database: Database): super().__init__(bot, name) @@ -78,24 +84,24 @@ class Server(BaseServer): await self.send(build("JOIN", [self._channel])) if ( - line.command == "PRIVMSG" and - self.has_channel(line.params[0]) and - not line.hostmask is None and - not self.casefold(line.hostmask.nickname) == me and - self.has_user(line.hostmask.nickname) and - line.params[1].startswith(TRIGGER)): + line.command == "PRIVMSG" + and self.has_channel(line.params[0]) + and not line.hostmask is None + and not self.casefold(line.hostmask.nickname) == me + and self.has_user(line.hostmask.nickname) + and line.params[1].startswith(TRIGGER) + ): channel = self.channels[self.casefold(line.params[0])] - user = self.users[self.casefold(line.hostmask.nickname)] - cuser = channel.users[user.nickname_lower] - text = line.params[1].replace(TRIGGER, "", 1) + user = self.users[self.casefold(line.hostmask.nickname)] + cuser = channel.users[user.nickname_lower] + text = line.params[1].replace(TRIGGER, "", 1) db_context = f"{self.name}:{channel.name}" - name, _, text = text.partition(" ") + name, _, text = text.partition(" ") action, _, text = text.partition(" ") name = name.lower() - key = f"factoid-{name}" - + key = f"factoid-{name}" out = "" if not action or action == "@": @@ -125,10 +131,8 @@ class Server(BaseServer): elif value: changed = _sed(value, current) if not changed is None: - await self._database.set( - db_context, key, changed) - out = (f"{user.nickname}: " - f"changed '{name}' factoid") + await self._database.set(db_context, key, changed) + out = f"{user.nickname}: " f"changed '{name}' factoid" else: out = f"{user.nickname}: invalid sed" else: @@ -136,29 +140,28 @@ class Server(BaseServer): else: out = f"{user.nickname}: you are not an op" - else: out = f"{user.nickname}: unknown action '{action}'" await self.send(build("PRIVMSG", [line.params[0], out])) + class Bot(BaseBot): def __init__(self, channel: str): super().__init__() self._channel = channel + def create_server(self, name: str): return Server(self, name, self._channel, Database()) + async def main(hostname: str, channel: str, nickname: str): bot = Bot(channel) - params = ConnectionParams( - nickname, - hostname, - 6697 - ) + params = ConnectionParams(nickname, hostname, 6697) await bot.add_server("freenode", params) await bot.run() + if __name__ == "__main__": parser = ArgumentParser(description="A simple IRC bot for factoids") parser.add_argument("hostname") diff --git a/examples/sasl.py b/examples/sasl.py index 97c81fa..b3e66a2 100644 --- a/examples/sasl.py +++ b/examples/sasl.py @@ -5,28 +5,31 @@ from ircrobots import Bot as BaseBot from ircrobots import Server as BaseServer from ircrobots import ConnectionParams, SASLUserPass, SASLSCRAM + class Server(BaseServer): async def line_read(self, line: Line): print(f"{self.name} < {line.format()}") + async def line_send(self, line: Line): print(f"{self.name} > {line.format()}") + class Bot(BaseBot): def create_server(self, name: str): return Server(self, name) + async def main(): bot = Bot() sasl_params = SASLUserPass("myusername", "invalidpassword") - params = ConnectionParams( - "MyNickname", - host = "chat.freenode.invalid", - port = 6697, - sasl = sasl_params) + params = ConnectionParams( + "MyNickname", host="chat.freenode.invalid", port=6697, sasl=sasl_params + ) await bot.add_server("freenode", params) await bot.run() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/simple.py b/examples/simple.py index e47fc1b..b9b336b 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -5,9 +5,8 @@ from ircrobots import Bot as BaseBot from ircrobots import Server as BaseServer from ircrobots import ConnectionParams -SERVERS = [ - ("freenode", "chat.freenode.invalid") -] +SERVERS = [("freenode", "chat.freenode.invalid")] + class Server(BaseServer): async def line_read(self, line: Line): @@ -15,13 +14,16 @@ class Server(BaseServer): if line.command == "001": print(f"connected to {self.isupport.network}") await self.send(build("JOIN", ["#testchannel"])) + async def line_send(self, line: Line): print(f"{self.name} > {line.format()}") + class Bot(BaseBot): def create_server(self, name: str): return Server(self, name) + async def main(): bot = Bot() for name, host in SERVERS: @@ -30,5 +32,6 @@ async def main(): await bot.run() + if __name__ == "__main__": asyncio.run(main()) diff --git a/ircrobots/__init__.py b/ircrobots/__init__.py index 5b798ed..b42a689 100644 --- a/ircrobots/__init__.py +++ b/ircrobots/__init__.py @@ -1,5 +1,11 @@ -from .bot import Bot +from .bot import Bot from .server import Server -from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, - STSPolicy, ResumePolicy) -from .ircv3 import Capability +from .params import ( + ConnectionParams, + SASLUserPass, + SASLExternal, + SASLSCRAM, + STSPolicy, + ResumePolicy, +) +from .ircv3 import Capability diff --git a/ircrobots/asyncs.py b/ircrobots/asyncs.py index 54d0b3b..0a723a1 100644 --- a/ircrobots/asyncs.py +++ b/ircrobots/asyncs.py @@ -1,13 +1,14 @@ -from asyncio import Future -from typing import (Any, Awaitable, Callable, Generator, Generic, Optional, - TypeVar) +from asyncio import Future +from typing import Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar -from irctokens import Line -from .matching import IMatchResponse +from irctokens import Line +from .matching import IMatchResponse from .interface import IServer -from .ircv3 import TAG_LABEL +from .ircv3 import TAG_LABEL TEvent = TypeVar("TEvent") + + class MaybeAwait(Generic[TEvent]): def __init__(self, func: Callable[[], Awaitable[TEvent]]): self._func = func @@ -16,13 +17,12 @@ class MaybeAwait(Generic[TEvent]): coro = self._func() return coro.__await__() + class WaitFor(object): - def __init__(self, - response: IMatchResponse, - deadline: float): + def __init__(self, response: IMatchResponse, deadline: float): self.response = response self.deadline = deadline - self._label: Optional[str] = None + self._label: Optional[str] = None self._our_fut: "Future[Line]" = Future() def __await__(self) -> Generator[Any, None, Line]: @@ -32,11 +32,9 @@ class WaitFor(object): self._label = label def match(self, server: IServer, line: Line): - if (self._label is not None and - line.tags is not None): + if self._label is not None and line.tags is not None: label = TAG_LABEL.get(line.tags) - if (label is not None and - label == self._label): + if label is not None and label == self._label: return True return self.response.match(server, line) diff --git a/ircrobots/bot.py b/ircrobots/bot.py index 11090d0..f509858 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -4,10 +4,11 @@ from typing import Dict from ircstates.server import ServerDisconnectedException -from .server import ConnectionParams, Server +from .server import ConnectionParams, Server from .transport import TCPTransport from .interface import IBot, IServer, ITCPTransport + class Bot(IBot): def __init__(self): self.servers: Dict[str, Server] = {} @@ -17,9 +18,11 @@ class Bot(IBot): return Server(self, name) async def disconnected(self, server: IServer): - if (server.name in self.servers and - server.params is not None and - server.disconnected): + if ( + server.name in self.servers + and server.params is not None + and server.disconnected + ): reconnect = server.params.reconnect @@ -30,7 +33,7 @@ class Bot(IBot): except Exception as e: traceback.print_exc() # let's try again, exponential backoff up to 5 mins - reconnect = min(reconnect*2, 300) + reconnect = min(reconnect * 2, 300) else: break @@ -38,10 +41,12 @@ class Bot(IBot): del self.servers[server.name] await server.disconnect() - async def add_server(self, - name: str, - params: ConnectionParams, - transport: ITCPTransport = TCPTransport()) -> Server: + async def add_server( + self, + name: str, + params: ConnectionParams, + transport: ITCPTransport = TCPTransport(), + ) -> Server: server = self.create_server(name) self.servers[name] = server await server.connect(transport, params) diff --git a/ircrobots/contexts.py b/ircrobots/contexts.py index 6b7bd3e..c9e275e 100644 --- a/ircrobots/contexts.py +++ b/ircrobots/contexts.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from .interface import IServer + @dataclass class ServerContext(object): server: IServer diff --git a/ircrobots/formatting.py b/ircrobots/formatting.py index d62e008..0bc34d0 100644 --- a/ircrobots/formatting.py +++ b/ircrobots/formatting.py @@ -1,19 +1,14 @@ from typing import List -BOLD = "\x02" -COLOR = "\x03" -INVERT = "\x16" -ITALIC = "\x1D" +BOLD = "\x02" +COLOR = "\x03" +INVERT = "\x16" +ITALIC = "\x1D" UNDERLINE = "\x1F" -RESET = "\x0F" +RESET = "\x0F" + +FORMATTERS = [BOLD, INVERT, ITALIC, UNDERLINE, RESET] -FORMATTERS = [ - BOLD, - INVERT, - ITALIC, - UNDERLINE, - RESET -] def tokens(s: str) -> List[str]: tokens: List[str] = [] @@ -25,9 +20,7 @@ def tokens(s: str) -> List[str]: for i in range(2): if s_copy and s_copy[0].isdigit(): token += s_copy.pop(0) - if (len(s_copy) > 1 and - s_copy[0] == "," and - s_copy[1].isdigit()): + if len(s_copy) > 1 and s_copy[0] == "," and s_copy[1].isdigit(): token += s_copy.pop(0) token += s_copy.pop(0) if s_copy and s_copy[0].isdigit(): @@ -38,6 +31,7 @@ def tokens(s: str) -> List[str]: tokens.append(token) return tokens + def strip(s: str): for token in tokens(s): s = s.replace(token, "", 1) diff --git a/ircrobots/glob.py b/ircrobots/glob.py index 0ad3c0a..080eff9 100644 --- a/ircrobots/glob.py +++ b/ircrobots/glob.py @@ -1,4 +1,3 @@ - def collapse(pattern: str) -> str: out = "" i = 0 @@ -15,9 +14,10 @@ def collapse(pattern: str) -> str: if pattern[i:]: out += pattern[i] - i += 1 + i += 1 return out + def _match(pattern: str, s: str): i, j = 0, 0 @@ -45,10 +45,14 @@ def _match(pattern: str, s: str): return i == len(pattern) + class Glob(object): def __init__(self, pattern: str): self._pattern = pattern + def match(self, s: str) -> bool: return _match(self._pattern, s) + + def compile(pattern: str) -> Glob: return Glob(collapse(pattern)) diff --git a/ircrobots/interface.py b/ircrobots/interface.py index db66353..45a4e14 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -1,16 +1,19 @@ from asyncio import Future -from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union -from enum import IntEnum +from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union +from enum import IntEnum from ircstates import Server, Emit from irctokens import Line, Hostmask -from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy +from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .security import TLS + class ITCPReader(object): async def read(self, byte_count: int): pass + + class ITCPWriter(object): def write(self, data: bytes): pass @@ -20,37 +23,40 @@ class ITCPWriter(object): async def drain(self): pass + async def close(self): pass + class ITCPTransport(object): - async def connect(self, - hostname: str, - port: int, - tls: Optional[TLS], - bindhost: Optional[str]=None - ) -> Tuple[ITCPReader, ITCPWriter]: + async def connect( + self, + hostname: str, + port: int, + tls: Optional[TLS], + bindhost: Optional[str] = None, + ) -> Tuple[ITCPReader, ITCPWriter]: pass + class SendPriority(IntEnum): - HIGH = 0 + HIGH = 0 MEDIUM = 10 - LOW = 20 + LOW = 20 DEFAULT = MEDIUM + class SentLine(object): - def __init__(self, - id: int, - priority: int, - line: Line): - self.id = id - self.priority = priority - self.line = line + def __init__(self, id: int, priority: int, line: Line): + self.id = id + self.priority = priority + self.line = line self.future: "Future[SentLine]" = Future() def __lt__(self, other: "SentLine") -> bool: return self.priority < other.priority + class ICapability(object): def available(self, capabilities: Iterable[str]) -> Optional[str]: pass @@ -61,38 +67,46 @@ class ICapability(object): def copy(self) -> "ICapability": pass + class IMatchResponse(object): def match(self, server: "IServer", line: Line) -> bool: pass + + class IMatchResponseParam(object): def match(self, server: "IServer", arg: str) -> bool: pass + + class IMatchResponseValueParam(IMatchResponseParam): def value(self, server: "IServer"): pass + def set_value(self, value: str): pass + + class IMatchResponseHostmask(object): def match(self, server: "IServer", hostmask: Hostmask) -> bool: pass + class IServer(Server): - bot: "IBot" + bot: "IBot" disconnected: bool - params: ConnectionParams + params: ConnectionParams desired_caps: Set[ICapability] - last_read: float + last_read: float - def send_raw(self, line: str, priority=SendPriority.DEFAULT - ) -> Awaitable[SentLine]: - pass - def send(self, line: Line, priority=SendPriority.DEFAULT - ) -> Awaitable[SentLine]: + def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]: pass - def wait_for(self, - response: Union[IMatchResponse, Set[IMatchResponse]] - ) -> Awaitable[Line]: + def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]: + pass + + def wait_for( + self, response: Union[IMatchResponse, Set[IMatchResponse]] + ) -> Awaitable[Line]: pass def set_throttle(self, rate: int, time: float): @@ -101,37 +115,44 @@ class IServer(Server): def server_address(self) -> Tuple[str, int]: pass - async def connect(self, - transport: ITCPTransport, - params: ConnectionParams): + async def connect(self, transport: ITCPTransport, params: ConnectionParams): pass + async def disconnect(self): pass def line_preread(self, line: Line): pass + def line_presend(self, line: Line): pass + async def line_read(self, line: Line): pass + async def line_send(self, line: Line): pass + async def sts_policy(self, sts: STSPolicy): pass + async def resume_policy(self, resume: ResumePolicy): pass def cap_agreed(self, capability: ICapability) -> bool: pass + def cap_available(self, capability: ICapability) -> Optional[str]: pass async def sasl_auth(self, sasl: SASLParams) -> bool: pass + class IBot(object): def create_server(self, name: str) -> IServer: pass + async def disconnected(self, server: IServer): pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index 359ca12..f518d90 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -1,22 +1,25 @@ -from time import time -from typing import Dict, Iterable, List, Optional, Tuple +from time import time +from typing import Dict, Iterable, List, Optional, Tuple from dataclasses import dataclass -from irctokens import build +from irctokens import build from ircstates.server import ServerDisconnectedException -from .contexts import ServerContext -from .matching import Response, ANY +from .contexts import ServerContext +from .matching import Response, ANY from .interface import ICapability -from .params import ConnectionParams, STSPolicy, ResumePolicy -from .security import TLS_VERIFYCHAIN +from .params import ConnectionParams, STSPolicy, ResumePolicy +from .security import TLS_VERIFYCHAIN + class Capability(ICapability): - def __init__(self, - ratified_name: Optional[str], - draft_name: Optional[str]=None, - alias: Optional[str]=None, - depends_on: List[str]=[]): - self.name = ratified_name + def __init__( + self, + ratified_name: Optional[str], + draft_name: Optional[str] = None, + alias: Optional[str] = None, + depends_on: List[str] = [], + ): + self.name = ratified_name self.draft = draft_name self.alias = alias or ratified_name self.depends_on = depends_on.copy() @@ -26,8 +29,7 @@ class Capability(ICapability): def match(self, capability: str) -> bool: return capability in self._caps - def available(self, capabilities: Iterable[str] - ) -> Optional[str]: + def available(self, capabilities: Iterable[str]) -> Optional[str]: for cap in self._caps: if not cap is None and cap in capabilities: return cap @@ -36,16 +38,13 @@ class Capability(ICapability): def copy(self): return Capability( - self.name, - self.draft, - alias=self.alias, - depends_on=self.depends_on[:]) + self.name, self.draft, alias=self.alias, depends_on=self.depends_on[:] + ) + class MessageTag(object): - def __init__(self, - name: Optional[str], - draft_name: Optional[str]=None): - self.name = name + def __init__(self, name: Optional[str], draft_name: Optional[str] = None): + self.name = name self.draft = draft_name self._tags = [self.name, self.draft] @@ -63,37 +62,36 @@ class MessageTag(object): else: return None -CAP_SASL = Capability("sasl") -CAP_ECHO = Capability("echo-message") -CAP_STS = Capability("sts", "draft/sts") + +CAP_SASL = Capability("sasl") +CAP_ECHO = Capability("echo-message") +CAP_STS = Capability("sts", "draft/sts") CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume") -CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") -TAG_LABEL = MessageTag("label", "draft/label") +CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") +TAG_LABEL = MessageTag("label", "draft/label") LABEL_TAG_MAP = { "draft/labeled-response-0.2": "draft/label", - "labeled-response": "label" + "labeled-response": "label", } CAPS: List[ICapability] = [ Capability("multi-prefix"), Capability("chghost"), Capability("away-notify"), - Capability("invite-notify"), Capability("account-tag"), Capability("account-notify"), Capability("extended-join"), - Capability("message-tags", "draft/message-tags-0.2"), Capability("cap-notify"), Capability("batch"), - Capability(None, "draft/rename", alias="rename"), Capability("setname", "draft/setname"), - CAP_RESUME + CAP_RESUME, ] + def _cap_dict(s: str) -> Dict[str, str]: d: Dict[str, str] = {} for token in s.split(","): @@ -101,41 +99,44 @@ def _cap_dict(s: str) -> Dict[str, str]: d[key] = value return d + async def sts_transmute(params: ConnectionParams): if not params.sts is None and params.tls is None: - now = time() - since = (now-params.sts.created) + now = time() + since = now - params.sts.created if since <= params.sts.duration: params.port = params.sts.port - params.tls = TLS_VERIFYCHAIN + params.tls = TLS_VERIFYCHAIN + + async def resume_transmute(params: ConnectionParams): if params.resume is not None: params.host = params.resume.address + class HandshakeCancel(Exception): pass + class CAPContext(ServerContext): async def on_ls(self, tokens: Dict[str, str]): await self._sts(tokens) - caps = list(self.server.desired_caps)+CAPS + caps = list(self.server.desired_caps) + CAPS - if (not self.server.params.sasl is None and - not CAP_SASL in caps): + if not self.server.params.sasl is None and not CAP_SASL in caps: caps.append(CAP_SASL) - matched = (c.available(tokens) for c in caps) + matched = (c.available(tokens) for c in caps) cap_names = [name for name in matched if not name is None] if cap_names: await self.server.send(build("CAP", ["REQ", " ".join(cap_names)])) while cap_names: - line = await self.server.wait_for({ - Response("CAP", [ANY, "ACK"]), - Response("CAP", [ANY, "NAK"]) - }) + line = await self.server.wait_for( + {Response("CAP", [ANY, "ACK"]), Response("CAP", [ANY, "NAK"])} + ) current_caps = line.params[2].split(" ") for cap in current_caps: @@ -144,8 +145,7 @@ class CAPContext(ServerContext): if CAP_RESUME.available(current_caps): await self.resume_token() - if (self.server.cap_agreed(CAP_SASL) and - not self.server.params.sasl is None): + if self.server.cap_agreed(CAP_SASL) and not self.server.params.sasl is None: await self.server.sasl_auth(self.server.params.sasl) async def resume_token(self): @@ -160,10 +160,9 @@ class CAPContext(ServerContext): if previous_policy is not None and not self.server.registered: await self.server.send(build("RESUME", [previous_policy.token])) - line = await self.server.wait_for({ - Response("RESUME", ["SUCCESS"]), - Response("FAIL", ["RESUME"]) - }) + line = await self.server.wait_for( + {Response("RESUME", ["SUCCESS"]), Response("FAIL", ["RESUME"])} + ) if line.command == "RESUME": raise HandshakeCancel() @@ -179,11 +178,11 @@ class CAPContext(ServerContext): cap_sts = CAP_STS.available(tokens) if not cap_sts is None: sts_dict = _cap_dict(tokens[cap_sts]) - params = self.server.params + params = self.server.params if not params.tls: if "port" in sts_dict: params.port = int(sts_dict["port"]) - params.tls = TLS_VERIFYCHAIN + params.tls = TLS_VERIFYCHAIN await self.server.bot.disconnect(self.server) await self.server.bot.add_server(self.server.name, params) @@ -194,6 +193,6 @@ class CAPContext(ServerContext): int(time()), params.port, int(sts_dict["duration"]), - "preload" in sts_dict) + "preload" in sts_dict, + ) await self.server.sts_policy(policy) - diff --git a/ircrobots/matching/__init__.py b/ircrobots/matching/__init__.py index 58608e6..3358c9a 100644 --- a/ircrobots/matching/__init__.py +++ b/ircrobots/matching/__init__.py @@ -1,3 +1,2 @@ - from .responses import * -from .params import * +from .params import * diff --git a/ircrobots/matching/params.py b/ircrobots/matching/params.py index c038db4..e1fd22c 100644 --- a/ircrobots/matching/params.py +++ b/ircrobots/matching/params.py @@ -1,16 +1,24 @@ -from re import compile as re_compile -from typing import Optional, Pattern, Union -from irctokens import Hostmask -from ..interface import (IMatchResponseParam, IMatchResponseValueParam, - IMatchResponseHostmask, IServer) -from ..glob import Glob, compile as glob_compile +from re import compile as re_compile +from typing import Optional, Pattern, Union +from irctokens import Hostmask +from ..interface import ( + IMatchResponseParam, + IMatchResponseValueParam, + IMatchResponseHostmask, + IServer, +) +from ..glob import Glob, compile as glob_compile from .. import formatting + class Any(IMatchResponseParam): def __repr__(self) -> str: return "Any()" + def match(self, server: IServer, arg: str) -> bool: return True + + ANY = Any() # NOT @@ -18,107 +26,142 @@ ANY = Any() # REGEX # LITERAL + class Literal(IMatchResponseValueParam): def __init__(self, value: str): self._value = value + def __repr__(self) -> str: return f"{self._value!r}" def value(self, server: IServer) -> str: return self._value + def set_value(self, value: str): self._value = value + def match(self, server: IServer, arg: str) -> bool: return arg == self._value -TYPE_MAYBELIT = Union[str, IMatchResponseParam] + +TYPE_MAYBELIT = Union[str, IMatchResponseParam] TYPE_MAYBELIT_VALUE = Union[str, IMatchResponseValueParam] + + def _assure_lit(value: TYPE_MAYBELIT_VALUE) -> IMatchResponseValueParam: if isinstance(value, str): return Literal(value) else: return value + class Not(IMatchResponseParam): def __init__(self, param: IMatchResponseParam): self._param = param + def __repr__(self) -> str: return f"Not({self._param!r})" + def match(self, server: IServer, arg: str) -> bool: return not self._param.match(server, arg) + class ParamValuePassthrough(IMatchResponseValueParam): _value: IMatchResponseValueParam + def value(self, server: IServer): return self._value.value(server) + def set_value(self, value: str): self._value.set_value(value) + class Folded(ParamValuePassthrough): def __init__(self, value: TYPE_MAYBELIT_VALUE): self._value = _assure_lit(value) self._folded = False + def __repr__(self) -> str: return f"Folded({self._value!r})" + def match(self, server: IServer, arg: str) -> bool: if not self._folded: - value = self.value(server) + value = self.value(server) folded = server.casefold(value) self.set_value(folded) self._folded = True return self._value.match(server, server.casefold(arg)) + class Formatless(IMatchResponseParam): def __init__(self, value: TYPE_MAYBELIT_VALUE): self._value = _assure_lit(value) + def __repr__(self) -> str: brepr = super().__repr__() return f"Formatless({brepr})" + def match(self, server: IServer, arg: str) -> bool: strip = formatting.strip(arg) return self._value.match(server, strip) + class Regex(IMatchResponseParam): def __init__(self, value: str): self._value = value self._pattern: Optional[Pattern] = None + def match(self, server: IServer, arg: str) -> bool: if self._pattern is None: self._pattern = re_compile(self._value) return bool(self._pattern.search(arg)) + class Self(IMatchResponseParam): def __repr__(self) -> str: return "Self()" + def match(self, server: IServer, arg: str) -> bool: return server.casefold(arg) == server.nickname_lower + + SELF = Self() + class MaskSelf(IMatchResponseHostmask): def __repr__(self) -> str: return "MaskSelf()" + def match(self, server: IServer, hostmask: Hostmask): return server.casefold(hostmask.nickname) == server.nickname_lower + + MASK_SELF = MaskSelf() + class Nick(IMatchResponseHostmask): def __init__(self, nickname: str): self._nickname = nickname self._folded: Optional[str] = None + def __repr__(self) -> str: return f"Nick({self._nickname!r})" + def match(self, server: IServer, hostmask: Hostmask): if self._folded is None: self._folded = server.casefold(self._nickname) return self._folded == server.casefold(hostmask.nickname) + class Mask(IMatchResponseHostmask): def __init__(self, mask: str): self._mask = mask self._compiled: Optional[Glob] + def __repr__(self) -> str: return f"Mask({self._mask!r})" + def match(self, server: IServer, hostmask: Hostmask): if self._compiled is None: self._compiled = glob_compile(self._mask) diff --git a/ircrobots/matching/responses.py b/ircrobots/matching/responses.py index ccc4c95..31ea4ba 100644 --- a/ircrobots/matching/responses.py +++ b/ircrobots/matching/responses.py @@ -1,17 +1,25 @@ -from typing import List, Optional, Sequence, Union -from irctokens import Line -from ..interface import (IServer, IMatchResponse, IMatchResponseParam, - IMatchResponseHostmask) -from .params import * +from typing import List, Optional, Sequence, Union +from irctokens import Line +from ..interface import ( + IServer, + IMatchResponse, + IMatchResponseParam, + IMatchResponseHostmask, +) +from .params import * TYPE_PARAM = Union[str, IMatchResponseParam] + + class Responses(IMatchResponse): - def __init__(self, - commands: Sequence[str], - params: Sequence[TYPE_PARAM]=[], - source: Optional[IMatchResponseHostmask]=None): + def __init__( + self, + commands: Sequence[str], + params: Sequence[TYPE_PARAM] = [], + source: Optional[IMatchResponseHostmask] = None, + ): self._commands = commands - self._source = source + self._source = source self._params: Sequence[IMatchResponseParam] = [] for param in params: @@ -25,36 +33,43 @@ class Responses(IMatchResponse): def match(self, server: IServer, line: Line) -> bool: for command in self._commands: - if (line.command == command and ( - self._source is None or ( - line.hostmask is not None and - self._source.match(server, line.hostmask) - ))): + if line.command == command and ( + self._source is None + or ( + line.hostmask is not None + and self._source.match(server, line.hostmask) + ) + ): for i, param in enumerate(self._params): - if (i >= len(line.params) or - not param.match(server, line.params[i])): + if i >= len(line.params) or not param.match(server, line.params[i]): break else: return True else: return False + class Response(Responses): - def __init__(self, - command: str, - params: Sequence[TYPE_PARAM]=[], - source: Optional[IMatchResponseHostmask]=None): + def __init__( + self, + command: str, + params: Sequence[TYPE_PARAM] = [], + source: Optional[IMatchResponseHostmask] = None, + ): super().__init__([command], params, source=source) def __repr__(self) -> str: return f"Response({self._commands[0]}: {self._params!r})" + class ResponseOr(IMatchResponse): def __init__(self, *responses: IMatchResponse): self._responses = responses + def __repr__(self) -> str: return f"ResponseOr({self._responses!r})" + def match(self, server: IServer, line: Line) -> bool: for response in self._responses: if response.match(server, line): diff --git a/ircrobots/params.py b/ircrobots/params.py index e52d6d6..89809f2 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -1,74 +1,80 @@ -from re import compile as re_compile -from typing import List, Optional +from re import compile as re_compile +from typing import List, Optional from dataclasses import dataclass, field from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN + class SASLParams(object): mechanism: str + @dataclass class _SASLUserPass(SASLParams): - username: str - password: str + username: str + password: str + class SASLUserPass(_SASLUserPass): mechanism = "USERPASS" + + class SASLSCRAM(_SASLUserPass): mechanism = "SCRAM" + + class SASLExternal(SASLParams): mechanism = "EXTERNAL" + @dataclass class STSPolicy(object): - created: int - port: int + created: int + port: int duration: int - preload: bool + preload: bool + @dataclass class ResumePolicy(object): address: str - token: str + token: str + RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]") -_TLS_TYPES = { - "+": TLS_VERIFYCHAIN, - "~": TLS_NOVERIFY -} +_TLS_TYPES = {"+": TLS_VERIFYCHAIN, "~": TLS_NOVERIFY} + + @dataclass class ConnectionParams(object): nickname: str - host: str - port: int - tls: Optional[TLS] = TLS_VERIFYCHAIN + host: str + port: int + tls: Optional[TLS] = TLS_VERIFYCHAIN username: Optional[str] = None realname: Optional[str] = None bindhost: Optional[str] = None - password: Optional[str] = None - sasl: Optional[SASLParams] = None + password: Optional[str] = None + sasl: Optional[SASLParams] = None - sts: Optional[STSPolicy] = None + sts: Optional[STSPolicy] = None resume: Optional[ResumePolicy] = None - reconnect: int = 10 # seconds + reconnect: int = 10 # seconds alt_nicknames: List[str] = field(default_factory=list) - autojoin: List[str] = field(default_factory=list) + autojoin: List[str] = field(default_factory=list) @staticmethod - def from_hoststring( - nickname: str, - hoststring: str - ) -> "ConnectionParams": + def from_hoststring(nickname: str, hoststring: str) -> "ConnectionParams": ipv6host = RE_IPV6HOST.search(hoststring) if ipv6host is not None and ipv6host.start() == 0: host = ipv6host.group(1) - port_s = hoststring[ipv6host.end()+1:] + port_s = hoststring[ipv6host.end() + 1 :] else: host, _, port_s = hoststring.strip().partition(":") diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index 8f3e21c..9e1ac65 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -1,52 +1,62 @@ -from typing import List -from enum import Enum -from base64 import b64decode, b64encode +from typing import List +from enum import Enum +from base64 import b64decode, b64encode from irctokens import build from ircstates.numerics import * from .matching import Responses, Response, ANY from .contexts import ServerContext -from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal -from .scram import SCRAMContext, SCRAMAlgorithm +from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal +from .scram import SCRAMContext, SCRAMAlgorithm SASL_SCRAM_MECHANISMS = [ "SCRAM-SHA-512", "SCRAM-SHA-256", "SCRAM-SHA-1", ] -SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"] +SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS + ["PLAIN"] + class SASLResult(Enum): - NONE = 0 + NONE = 0 SUCCESS = 1 FAILURE = 2 ALREADY = 3 + class SASLError(Exception): pass + + class SASLUnknownMechanismError(SASLError): pass + AUTH_BYTE_MAX = 400 AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY]) -NUMERICS_FAIL = Response(ERR_SASLFAIL) -NUMERICS_INITIAL = Responses([ - ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED -]) -NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) +NUMERICS_FAIL = Response(ERR_SASLFAIL) +NUMERICS_INITIAL = Responses( + [ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED] +) +NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) + def _b64e(s: str): return b64encode(s.encode("utf8")).decode("ascii") + def _b64eb(s: bytes) -> str: # encode-from-bytes return b64encode(s).decode("ascii") + + def _b64db(s: str) -> bytes: # decode-to-bytes return b64decode(s) + class SASLContext(ServerContext): async def from_params(self, params: SASLParams) -> SASLResult: if isinstance(params, SASLUserPass): @@ -57,15 +67,12 @@ class SASLContext(ServerContext): return await self.external() else: raise SASLUnknownMechanismError( - "SASLParams given with unknown mechanism " - f"{params.mechanism!r}") + "SASLParams given with unknown mechanism " f"{params.mechanism!r}" + ) async def external(self) -> SASLResult: await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) - line = await self.server.wait_for({ - AUTHENTICATE_ANY, - NUMERICS_INITIAL - }) + line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL}) if line.command == "907": # we've done SASL already. cleanly abort @@ -73,8 +80,8 @@ class SASLContext(ServerContext): elif line.command == "908": available = line.params[1].split(",") raise SASLUnknownMechanismError( - "Server does not support SASL EXTERNAL " - f"(it supports {available}") + "Server does not support SASL EXTERNAL " f"(it supports {available}" + ) elif line.command == "AUTHENTICATE" and line.params[0] == "+": await self.server.send(build("AUTHENTICATE", ["+"])) @@ -89,11 +96,12 @@ class SASLContext(ServerContext): async def scram(self, username: str, password: str) -> SASLResult: return await self.userpass(username, password, SASL_SCRAM_MECHANISMS) - async def userpass(self, - username: str, - password: str, - mechanisms: List[str]=SASL_USERPASS_MECHANISMS - ) -> SASLResult: + async def userpass( + self, + username: str, + password: str, + mechanisms: List[str] = SASL_USERPASS_MECHANISMS, + ) -> SASLResult: def _common(server_mechs) -> List[str]: mechs: List[str] = [] for our_mech in mechanisms: @@ -106,23 +114,21 @@ class SASLContext(ServerContext): raise SASLUnknownMechanismError( "No matching SASL mechanims. " f"(we want: {mechanisms} " - f"server has: {server_mechs})") + f"server has: {server_mechs})" + ) if self.server.available_caps["sasl"]: # CAP v3.2 tells us what mechs it supports available = self.server.available_caps["sasl"].split(",") - match = _common(available) + match = _common(available) else: # CAP v3.1 does not. pick the pick and wait for 907 to inform us of # what mechanisms are supported - match = mechanisms + match = mechanisms while match: await self.server.send(build("AUTHENTICATE", [match[0]])) - line = await self.server.wait_for({ - AUTHENTICATE_ANY, - NUMERICS_INITIAL - }) + line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL}) if line.command == "907": # we've done SASL already. cleanly abort @@ -130,7 +136,7 @@ class SASLContext(ServerContext): elif line.command == "908": # prior to CAP v3.2 - ERR telling us which mechs are supported available = line.params[1].split(",") - match = _common(available) + match = _common(available) await self.server.wait_for(NUMERICS_FAIL) elif line.command == "AUTHENTICATE" and line.params[0] == "+": auth_text = "" @@ -138,8 +144,7 @@ class SASLContext(ServerContext): if match[0] == "PLAIN": auth_text = f"{username}\0{username}\0{password}" elif match[0].startswith("SCRAM-SHA-"): - auth_text = await self._scram( - match[0], username, password) + auth_text = await self._scram(match[0], username, password) if not auth_text == "+": auth_text = _b64e(auth_text) @@ -148,7 +153,7 @@ class SASLContext(ServerContext): await self._send_auth_text(auth_text) line = await self.server.wait_for(NUMERICS_LAST) - if line.command == "903": + if line.command == "903": return SASLResult.SUCCESS elif line.command == "904": match.pop(0) @@ -157,11 +162,8 @@ class SASLContext(ServerContext): return SASLResult.FAILURE - async def _scram(self, algo_str: str, - username: str, - password: str) -> str: - algo_str_prep = algo_str.replace("SCRAM-", "", 1 - ).replace("-", "").upper() + async def _scram(self, algo_str: str, username: str, password: str) -> str: + algo_str_prep = algo_str.replace("SCRAM-", "", 1).replace("-", "").upper() try: algo = SCRAMAlgorithm(algo_str_prep) except ValueError: @@ -179,15 +181,15 @@ class SASLContext(ServerContext): line = await self.server.wait_for(AUTHENTICATE_ANY) server_final = _b64db(line.params[0]) - verified = scram.server_final(server_final) - #TODO PANIC if verified is false! + verified = scram.server_final(server_final) + # TODO PANIC if verified is false! return "+" else: return "" async def _send_auth_text(self, text: str): n = AUTH_BYTE_MAX - chunks = [text[i:i+n] for i in range(0, len(text), n)] + chunks = [text[i : i + n] for i in range(0, len(text), n)] if len(chunks[-1]) == 400: chunks.append("+") diff --git a/ircrobots/scram.py b/ircrobots/scram.py index 7de3cf8..10ce1ae 100644 --- a/ircrobots/scram.py +++ b/ircrobots/scram.py @@ -7,51 +7,60 @@ from typing import Dict # https://www.iana.org/assignments/hash-function-text-names/ # MD2 has been removed as it's unacceptably weak class SCRAMAlgorithm(Enum): - MD5 = "MD5" - SHA_1 = "SHA1" + MD5 = "MD5" + SHA_1 = "SHA1" SHA_224 = "SHA224" SHA_256 = "SHA256" SHA_384 = "SHA384" SHA_512 = "SHA512" + SCRAM_ERRORS = [ "invalid-encoding", - "extensions-not-supported", # unrecognized 'm' value + "extensions-not-supported", # unrecognized 'm' value "invalid-proof", "channel-bindings-dont-match", "server-does-support-channel-binding", "channel-binding-not-supported", "unsupported-channel-binding-type", "unknown-user", - "invalid-username-encoding", # invalid utf8 or bad SASLprep - "no-resources" + "invalid-username-encoding", # invalid utf8 or bad SASLprep + "no-resources", ] + def _scram_nonce() -> bytes: return base64.b64encode(os.urandom(32)) + + def _scram_escape(s: bytes) -> bytes: return s.replace(b"=", b"=3D").replace(b",", b"=2C") + + def _scram_unescape(s: bytes) -> bytes: return s.replace(b"=3D", b"=").replace(b"=2C", b",") + + def _scram_xor(s1: bytes, s2: bytes) -> bytes: return bytes(a ^ b for a, b in zip(s1, s2)) + class SCRAMState(Enum): - NONE = 0 - CLIENT_FIRST = 1 - CLIENT_FINAL = 2 - SUCCESS = 3 - FAILURE = 4 + NONE = 0 + CLIENT_FIRST = 1 + CLIENT_FINAL = 2 + SUCCESS = 3 + FAILURE = 4 VERIFY_FAILURE = 5 + class SCRAMError(Exception): pass + class SCRAMContext(object): - def __init__(self, algo: SCRAMAlgorithm, - username: str, - password: str): - self._algo = algo + def __init__(self, algo: SCRAMAlgorithm, username: str, password: str): + self._algo = algo self._username = username.encode("utf8") self._password = password.encode("utf8") @@ -59,11 +68,11 @@ class SCRAMContext(object): self.error = "" self.raw_error = "" - self._client_first = b"" - self._client_nonce = b"" + self._client_first = b"" + self._client_nonce = b"" self._salted_password = b"" - self._auth_message = b"" + self._auth_message = b"" def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]: pieces = (piece.split(b"=", 1) for piece in data.split(b",")) @@ -71,6 +80,7 @@ class SCRAMContext(object): def _hmac(self, key: bytes, msg: bytes) -> bytes: return hmac.new(key, msg, self._algo.value).digest() + def _hash(self, msg: bytes) -> bytes: return hashlib.new(self._algo.value, msg).digest() @@ -89,7 +99,9 @@ class SCRAMContext(object): self.state = SCRAMState.CLIENT_FIRST self._client_nonce = _scram_nonce() self._client_first = b"n=%s,r=%s" % ( - _scram_escape(self._username), self._client_nonce) + _scram_escape(self._username), + self._client_nonce, + ) # n,,n=,r= return b"n,,%s" % self._client_first @@ -109,17 +121,17 @@ class SCRAMContext(object): if self._assert_error(pieces): return b"" - nonce = pieces[b"r"] # server combines your nonce with it's own - if (not nonce.startswith(self._client_nonce) or - nonce == self._client_nonce): + nonce = pieces[b"r"] # server combines your nonce with it's own + if not nonce.startswith(self._client_nonce) or nonce == self._client_nonce: self._fail("nonce-unacceptable") return b"" - salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded + salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded iterations = int(pieces[b"i"]) - salted_password = hashlib.pbkdf2_hmac(self._algo.value, - self._password, salt, iterations, dklen=None) + salted_password = hashlib.pbkdf2_hmac( + self._algo.value, self._password, salt, iterations, dklen=None + ) self._salted_password = salted_password client_key = self._hmac(salted_password, b"Client Key") diff --git a/ircrobots/security.py b/ircrobots/security.py index f10b700..1cfa5ad 100644 --- a/ircrobots/security.py +++ b/ircrobots/security.py @@ -1,26 +1,35 @@ import ssl + class TLS: pass + # tls without verification class TLSNoVerify(TLS): pass + + TLS_NOVERIFY = TLSNoVerify() # verify via CAs class TLSVerifyChain(TLS): pass + + TLS_VERIFYCHAIN = TLSVerifyChain() # verify by a pinned hash class TLSVerifyHash(TLSNoVerify): def __init__(self, sum: str): self.sum = sum.lower() + + class TLSVerifySHA512(TLSVerifyHash): pass -def tls_context(verify: bool=True) -> ssl.SSLContext: + +def tls_context(verify: bool = True) -> ssl.SSLContext: ctx = ssl.create_default_context() if not verify: ctx.check_hostname = False diff --git a/ircrobots/server.py b/ircrobots/server.py index d29946d..00b8028 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,36 +1,58 @@ import asyncio -from asyncio import Future, PriorityQueue -from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List, - Optional, Set, Tuple, Union) +from asyncio import Future, PriorityQueue +from typing import ( + AsyncIterable, + Awaitable, + Deque, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) from collections import deque -from time import monotonic +from time import monotonic import anyio -from asyncio_rlock import RLock -from asyncio_throttle import Throttler -from async_timeout import timeout as timeout_ -from ircstates import Emit, Channel, ChannelUser +from asyncio_rlock import RLock +from asyncio_throttle import Throttler +from async_timeout import timeout as timeout_ +from ircstates import Emit, Channel, ChannelUser from ircstates.numerics import * -from ircstates.server import ServerDisconnectedException -from ircstates.names import Name -from irctokens import build, Line, tokenise +from ircstates.server import ServerDisconnectedException +from ircstates.names import Name +from irctokens import build, Line, tokenise -from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, - CAP_LABEL, LABEL_TAG_MAP, resume_transmute) -from .sasl import SASLContext, SASLResult -from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, - Folded) -from .asyncs import MaybeAwait, WaitFor -from .struct import Whois -from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy -from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, - IMatchResponse) +from .ircv3 import ( + CAPContext, + sts_transmute, + CAP_ECHO, + CAP_SASL, + CAP_LABEL, + LABEL_TAG_MAP, + resume_transmute, +) +from .sasl import SASLContext, SASLResult +from .matching import ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, Folded +from .asyncs import MaybeAwait, WaitFor +from .struct import Whois +from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy +from .interface import ( + IBot, + ICapability, + IServer, + SentLine, + SendPriority, + IMatchResponse, +) from .interface import ITCPTransport, ITCPReader, ITCPWriter THROTTLE_RATE = 4 # lines THROTTLE_TIME = 2 # seconds -PING_TIMEOUT = 60 # seconds -WAIT_TIMEOUT = 20 # seconds +PING_TIMEOUT = 60 # seconds +WAIT_TIMEOUT = 20 # seconds JOIN_ERR_FIRST = [ ERR_NOSUCHCHANNEL, @@ -41,13 +63,14 @@ JOIN_ERR_FIRST = [ ERR_INVITEONLYCHAN, ERR_BADCHANNELKEY, ERR_NEEDREGGEDNICK, - ERR_THROTTLE + ERR_THROTTLE, ] + class Server(IServer): _reader: ITCPReader _writer: ITCPWriter - params: ConnectionParams + params: ConnectionParams def __init__(self, bot: IBot, name: str): super().__init__(name) @@ -58,23 +81,23 @@ class Server(IServer): self.throttle = Throttler(rate_limit=100, period=1) self.sasl_state = SASLResult.NONE - self.last_read = monotonic() + self.last_read = monotonic() - self._sent_count: int = 0 + self._sent_count: int = 0 self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self.desired_caps: Set[ICapability] = set([]) - self._read_queue: Deque[Line] = deque() + self._read_queue: Deque[Line] = deque() self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() - self._ping_sent = False + self._ping_sent = False self._read_lguard = RLock() - self.read_lock = self._read_lguard - self._read_lwork = asyncio.Lock() - self._wait_for = asyncio.Event() + self.read_lock = self._read_lguard + self._read_lwork = asyncio.Lock() + self._wait_for = asyncio.Event() self._pending_who: Deque[str] = deque() - self._alt_nicks: List[str] = [] + self._alt_nicks: List[str] = [] def hostmask(self) -> str: hostmask = self.nickname @@ -84,13 +107,10 @@ class Server(IServer): hostmask += f"@{self.hostname}" return hostmask - def send_raw(self, line: str, priority=SendPriority.DEFAULT - ) -> Awaitable[SentLine]: + def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]: return self.send(tokenise(line), priority) - def send(self, - line: Line, - priority=SendPriority.DEFAULT - ) -> Awaitable[SentLine]: + + def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]: self.line_presend(line) sent_line = SentLine(self._sent_count, priority, line) @@ -110,28 +130,25 @@ class Server(IServer): def set_throttle(self, rate: int, time: float): self.throttle.rate_limit = rate - self.throttle.period = time + self.throttle.period = time def server_address(self) -> Tuple[str, int]: return self._writer.get_peer() - async def connect(self, - transport: ITCPTransport, - params: ConnectionParams): + async def connect(self, transport: ITCPTransport, params: ConnectionParams): await sts_transmute(params) await resume_transmute(params) reader, writer = await transport.connect( - params.host, - params.port, - tls =params.tls, - bindhost =params.bindhost) + params.host, params.port, tls=params.tls, bindhost=params.bindhost + ) self._reader = reader self._writer = writer self.params = params await self.handshake() + async def disconnect(self): if not self._writer is None: await self._writer.close() @@ -145,29 +162,35 @@ class Server(IServer): alt_nicks = self.params.alt_nicknames if not alt_nicks: - alt_nicks = [nickname+"_"*i for i in range(1, 4)] - self._alt_nicks = alt_nicks + alt_nicks = [nickname + "_" * i for i in range(1, 4)] + self._alt_nicks = alt_nicks # these must remain non-awaited; reading hasn't started yet if not self.params.password is None: self.send(build("PASS", [self.params.password])) - self.send(build("CAP", ["LS", "302"])) + self.send(build("CAP", ["LS", "302"])) self.send(build("NICK", [nickname])) self.send(build("USER", [username, "0", "*", realname])) # to be overridden def line_preread(self, line: Line): pass + def line_presend(self, line: Line): pass + async def line_read(self, line: Line): pass + async def line_send(self, line: Line): pass + async def sts_policy(self, sts: STSPolicy): pass + async def resume_policy(self, resume: ResumePolicy): pass + # /to be overriden async def _on_read(self, line: Line, emit: Optional[Emit]): @@ -176,13 +199,14 @@ class Server(IServer): elif line.command == RPL_ENDOFWHO: chan = self.casefold(line.params[1]) - if (self._pending_who and - self._pending_who[0] == chan): + if self._pending_who and self._pending_who[0] == chan: self._pending_who.popleft() await self._next_who() - elif (line.command in { - ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE - } and not self.registered): + elif ( + line.command + in {ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE} + and not self.registered + ): if self._alt_nicks: nick = self._alt_nicks.pop(0) await self.send(build("NICK", [nick])) @@ -203,8 +227,7 @@ class Server(IServer): await self._check_regain([line.params[1]]) elif line.command == RPL_MONOFFLINE: await self._check_regain(line.params[1].split(",")) - elif (line.command in ["NICK", "QUIT"] and - line.source is not None): + elif line.command in ["NICK", "QUIT"] and line.source is not None: await self._check_regain([line.hostmask.nickname]) elif emit is not None: @@ -216,10 +239,9 @@ class Server(IServer): await self._batch_joins(self.params.autojoin) elif emit.command == "CAP": - if emit.subcommand == "NEW": + if emit.subcommand == "NEW": await self._cap_ls(emit) - elif (emit.subcommand == "LS" and - emit.finished): + elif emit.subcommand == "LS" and emit.finished: if not self.registered: await CAPContext(self).handshake() else: @@ -227,7 +249,7 @@ class Server(IServer): elif emit.command == "JOIN": if emit.self and not emit.channel is None: - chan = emit.channel.name_lower + chan = emit.channel.name_lower await self.send(build("MODE", [chan])) modes = "".join(self.isupport.chanmodes.a_modes) @@ -241,18 +263,18 @@ class Server(IServer): async def _check_regain(self, nicks: List[str]): for nick in nicks: - if (self.casefold_equals(nick, self.params.nickname) and - not self.nickname == self.params.nickname): + if ( + self.casefold_equals(nick, self.params.nickname) + and not self.nickname == self.params.nickname + ): await self.send(build("NICK", [self.params.nickname])) - async def _batch_joins(self, - channels: List[str], - batch_n: int=10): - #TODO: do as many JOINs in one line as we can fit - #TODO: channel keys + async def _batch_joins(self, channels: List[str], batch_n: int = 10): + # TODO: do as many JOINs in one line as we can fit + # TODO: channel keys for i in range(0, len(channels), batch_n): - batch = channels[i:i+batch_n] + batch = channels[i : i + batch_n] await self.send(build("JOIN", [",".join(batch)])) async def _next_who(self): @@ -275,7 +297,7 @@ class Server(IServer): return None self.last_read = monotonic() - lines = self.recv(data) + lines = self.recv(data) for line in lines: self.line_preread(line) self._read_queue.append(line) @@ -287,10 +309,10 @@ class Server(IServer): if not self._process_queue: async with self._read_lwork: - read_aw = self._read_line(PING_TIMEOUT) + read_aw = self._read_line(PING_TIMEOUT) dones, notdones = await asyncio.wait( [read_aw, self._wait_for.wait()], - return_when=asyncio.FIRST_COMPLETED + return_when=asyncio.FIRST_COMPLETED, ) self._wait_for.clear() @@ -314,11 +336,12 @@ class Server(IServer): line, emit = self._process_queue.popleft() await self._on_read(line, emit) - async def wait_for(self, - response: Union[IMatchResponse, Set[IMatchResponse]], - sent_aw: Optional[Awaitable[SentLine]]=None, - timeout: float=WAIT_TIMEOUT - ) -> Line: + async def wait_for( + self, + response: Union[IMatchResponse, Set[IMatchResponse]], + sent_aw: Optional[Awaitable[SentLine]] = None, + timeout: float = WAIT_TIMEOUT, + ) -> Line: response_obj: IMatchResponse if isinstance(response, set): @@ -340,8 +363,9 @@ class Server(IServer): return line async def _on_send_line(self, line: Line): - if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and - not self.cap_agreed(CAP_ECHO)): + if line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and not self.cap_agreed( + CAP_ECHO + ): new_line = line.with_source(self.hostmask()) self._read_queue.append(new_line) @@ -349,15 +373,13 @@ class Server(IServer): while True: lines: List[SentLine] = [] - while (not lines or - (len(lines) < 5 and self._send_queue.qsize() > 0)): + while not lines or (len(lines) < 5 and self._send_queue.qsize() > 0): prio_line = await self._send_queue.get() lines.append(prio_line) for line in lines: async with self.throttle: - self._writer.write( - f"{line.line.format()}\r\n".encode("utf8")) + self._writer.write(f"{line.line.format()}\r\n".encode("utf8")) await self._writer.drain() @@ -369,6 +391,7 @@ class Server(IServer): # CAP-related def cap_agreed(self, capability: ICapability) -> bool: return bool(self.cap_available(capability)) + def cap_available(self, capability: ICapability) -> Optional[str]: return capability.available(self.agreed_caps) @@ -381,78 +404,81 @@ class Server(IServer): await CAPContext(self).on_ls(tokens) async def sasl_auth(self, params: SASLParams) -> bool: - if (self.sasl_state == SASLResult.NONE and - self.cap_agreed(CAP_SASL)): + if self.sasl_state == SASLResult.NONE and self.cap_agreed(CAP_SASL): res = await SASLContext(self).from_params(params) self.sasl_state = res return True else: return False + # /CAP-related def send_nick(self, new_nick: str) -> Awaitable[bool]: fut = self.send(build("NICK", [new_nick])) + async def _assure() -> bool: - line = await self.wait_for({ - Response("NICK", [Folded(new_nick)], source=MASK_SELF), - Responses([ - ERR_BANNICKCHANGE, - ERR_NICKTOOFAST, - ERR_CANTCHANGENICK - ], [ANY]), - Responses([ - ERR_NICKNAMEINUSE, - ERR_ERRONEUSNICKNAME, - ERR_UNAVAILRESOURCE - ], [ANY, Folded(new_nick)]) - }, fut) + line = await self.wait_for( + { + Response("NICK", [Folded(new_nick)], source=MASK_SELF), + Responses( + [ERR_BANNICKCHANGE, ERR_NICKTOOFAST, ERR_CANTCHANGENICK], [ANY] + ), + Responses( + [ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE], + [ANY, Folded(new_nick)], + ), + }, + fut, + ) return line.command == "NICK" + return MaybeAwait(_assure) - def send_join(self, - name: str, - key: Optional[str]=None - ) -> Awaitable[Channel]: + def send_join(self, name: str, key: Optional[str] = None) -> Awaitable[Channel]: fut = self.send_joins([name], [] if key is None else [key]) async def _assure(): channels = await fut return channels[0] + return MaybeAwait(_assure) + def send_part(self, name: str): fut = self.send(build("PART", [name])) async def _assure(): line = await self.wait_for( - Response("PART", [Folded(name)], source=MASK_SELF), - fut + Response("PART", [Folded(name)], source=MASK_SELF), fut ) return + return MaybeAwait(_assure) - def send_joins(self, - names: List[str], - keys: List[str]=[] - ) -> Awaitable[List[Channel]]: + def send_joins( + self, names: List[str], keys: List[str] = [] + ) -> Awaitable[List[Channel]]: folded_names = [self.casefold(name) for name in names] if not keys: fut = self.send(build("JOIN", [",".join(names)])) else: - fut = self.send(build("JOIN", [",".join(names)]+keys)) + fut = self.send(build("JOIN", [",".join(names)] + keys)) async def _assure(): channels: List[Channel] = [] while folded_names: - line = await self.wait_for({ - Response(RPL_CHANNELMODEIS, [ANY, ANY]), - Responses(JOIN_ERR_FIRST, [ANY, ANY]), - Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]), - Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]) - }, fut) + line = await self.wait_for( + { + Response(RPL_CHANNELMODEIS, [ANY, ANY]), + Responses(JOIN_ERR_FIRST, [ANY, ANY]), + Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]), + Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]), + }, + fut, + ) chan: Optional[str] = None if line.command == RPL_CHANNELMODEIS: @@ -462,7 +488,7 @@ class Server(IServer): elif line.command == ERR_USERONCHANNEL: chan = line.params[2] elif line.command == ERR_LINKCHANNEL: - #XXX i dont like this + # XXX i dont like this chan = line.params[2] await self.wait_for( Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)]) @@ -477,51 +503,58 @@ class Server(IServer): channels.append(self.channels[folded]) return channels + return MaybeAwait(_assure) - def send_message(self, target: str, message: str - ) -> Awaitable[Optional[str]]: + def send_message(self, target: str, message: str) -> Awaitable[Optional[str]]: fut = self.send(build("PRIVMSG", [target, message])) + async def _assure(): line = await self.wait_for( - Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), - fut + Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), fut ) if line.command == "PRIVMSG": return line.params[1] else: return None + return MaybeAwait(_assure) - def send_whois(self, - target: str, - remote: bool=False - ) -> Awaitable[Optional[Whois]]: + def send_whois( + self, target: str, remote: bool = False + ) -> Awaitable[Optional[Whois]]: args = [target] if remote: args.append(target) fut = self.send(build("WHOIS", args)) + async def _assure() -> Optional[Whois]: folded = self.casefold(target) params = [ANY, Folded(folded)] obj = Whois() while True: - line = await self.wait_for(Responses([ - ERR_NOSUCHNICK, - ERR_NOSUCHSERVER, - RPL_WHOISUSER, - RPL_WHOISSERVER, - RPL_WHOISOPERATOR, - RPL_WHOISIDLE, - RPL_WHOISCHANNELS, - RPL_WHOISHOST, - RPL_WHOISACCOUNT, - RPL_WHOISSECURE, - RPL_ENDOFWHOIS - ], params), fut) - if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]: + line = await self.wait_for( + Responses( + [ + ERR_NOSUCHNICK, + ERR_NOSUCHSERVER, + RPL_WHOISUSER, + RPL_WHOISSERVER, + RPL_WHOISOPERATOR, + RPL_WHOISIDLE, + RPL_WHOISCHANNELS, + RPL_WHOISHOST, + RPL_WHOISACCOUNT, + RPL_WHOISSECURE, + RPL_ENDOFWHOIS, + ], + params, + ), + fut, + ) + if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]: return None elif line.command == RPL_WHOISUSER: nick, user, host, _, real = line.params[1:] @@ -531,7 +564,7 @@ class Server(IServer): obj.realname = real elif line.command == RPL_WHOISIDLE: idle, signon, _ = line.params[2:] - obj.idle = int(idle) + obj.idle = int(idle) obj.signon = int(signon) elif line.command == RPL_WHOISACCOUNT: obj.account = line.params[2] @@ -544,11 +577,11 @@ class Server(IServer): symbols = "" while channel[0] in self.isupport.prefix.prefixes: symbols += channel[0] - channel = channel[1:] + channel = channel[1:] channel_user = ChannelUser( Name(obj.nickname, folded), - Name(channel, self.casefold(channel)) + Name(channel, self.casefold(channel)), ) for symbol in symbols: mode = self.isupport.prefix.from_prefix(symbol) @@ -558,4 +591,5 @@ class Server(IServer): obj.channels.append(channel_user) elif line.command == RPL_ENDOFWHOIS: return obj + return MaybeAwait(_assure) diff --git a/ircrobots/struct.py b/ircrobots/struct.py index 1733e72..53c535d 100644 --- a/ircrobots/struct.py +++ b/ircrobots/struct.py @@ -3,21 +3,21 @@ from dataclasses import dataclass from ircstates import ChannelUser + class Whois(object): - server: Optional[str] = None - server_info: Optional[str] = None - operator: bool = False + server: Optional[str] = None + server_info: Optional[str] = None + operator: bool = False - secure: bool = False + secure: bool = False - signon: Optional[int] = None - idle: Optional[int] = None + signon: Optional[int] = None + idle: Optional[int] = None - channels: Optional[List[ChannelUser]] = None + channels: Optional[List[ChannelUser]] = None nickname: str = "" username: str = "" hostname: str = "" realname: str = "" - account: Optional[str] = None - + account: Optional[str] = None diff --git a/ircrobots/transport.py b/ircrobots/transport.py index a7cb330..2aa6146 100644 --- a/ircrobots/transport.py +++ b/ircrobots/transport.py @@ -1,12 +1,12 @@ -from hashlib import sha512 -from ssl import SSLContext -from typing import Optional, Tuple -from asyncio import StreamReader, StreamWriter +from hashlib import sha512 +from ssl import SSLContext +from typing import Optional, Tuple +from asyncio import StreamReader, StreamWriter from async_stagger import open_connection from .interface import ITCPTransport, ITCPReader, ITCPWriter -from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash, - TLSVerifySHA512) +from .security import tls_context, TLS, TLSNoVerify, TLSVerifyHash, TLSVerifySHA512 + class TCPReader(ITCPReader): def __init__(self, reader: StreamReader): @@ -14,6 +14,8 @@ class TCPReader(ITCPReader): async def read(self, byte_count: int) -> bytes: return await self._reader.read(byte_count) + + class TCPWriter(ITCPWriter): def __init__(self, writer: StreamWriter): self._writer = writer @@ -32,13 +34,15 @@ class TCPWriter(ITCPWriter): self._writer.close() await self._writer.wait_closed() + class TCPTransport(ITCPTransport): - async def connect(self, - hostname: str, - port: int, - tls: Optional[TLS], - bindhost: Optional[str]=None - ) -> Tuple[ITCPReader, ITCPWriter]: + async def connect( + self, + hostname: str, + port: int, + tls: Optional[TLS], + bindhost: Optional[str] = None, + ) -> Tuple[ITCPReader, ITCPWriter]: cur_ssl: Optional[SSLContext] = None if tls is not None: @@ -54,22 +58,20 @@ class TCPTransport(ITCPTransport): hostname, port, server_hostname=server_hostname, - ssl =cur_ssl, - local_addr =local_addr) + ssl=cur_ssl, + local_addr=local_addr, + ) if isinstance(tls, TLSVerifyHash): - cert: bytes = writer.transport.get_extra_info( - "ssl_object" - ).getpeercert(True) + cert: bytes = writer.transport.get_extra_info("ssl_object").getpeercert( + True + ) if isinstance(tls, TLSVerifySHA512): sum = sha512(cert).hexdigest() else: raise ValueError(f"unknown hash pinning {type(tls)}") if not sum == tls.sum: - raise ValueError( - f"pinned hash for {hostname} does not match ({sum})" - ) + raise ValueError(f"pinned hash for {hostname} does not match ({sum})") return (TCPReader(reader), TCPWriter(writer)) - diff --git a/setup.py b/setup.py index b4b6b17..f4e2dc5 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,8 @@ setup( "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: Microsoft :: Windows", - "Topic :: Communications :: Chat :: Internet Relay Chat" + "Topic :: Communications :: Chat :: Internet Relay Chat", ], - python_requires='>=3.7', - install_requires=install_requires + python_requires=">=3.7", + install_requires=install_requires, ) diff --git a/test/glob.py b/test/glob.py index ddc6297..5109991 100644 --- a/test/glob.py +++ b/test/glob.py @@ -1,6 +1,7 @@ import unittest from ircrobots import glob + class GlobTestCollapse(unittest.TestCase): def test(self): c1 = glob.collapse("**?*")