formatted with black

This commit is contained in:
A_D 2022-01-30 20:05:31 +02:00
parent 0ce3b9b0b0
commit 9b31aff951
No known key found for this signature in database
GPG Key ID: 4BE9EB7DF45076C4
23 changed files with 629 additions and 469 deletions

View File

@ -1,32 +1,34 @@
import asyncio, re import asyncio, re
from argparse import ArgumentParser from argparse import ArgumentParser
from typing import Dict, List, Optional from typing import Dict, List, Optional
from irctokens import build, Line from irctokens import build, Line
from ircrobots import Bot as BaseBot from ircrobots import Bot as BaseBot
from ircrobots import Server as BaseServer from ircrobots import Server as BaseServer
from ircrobots import ConnectionParams from ircrobots import ConnectionParams
TRIGGER = "!" TRIGGER = "!"
def _delims(s: str, delim: str): def _delims(s: str, delim: str):
s_copy = list(s) s_copy = list(s)
while s_copy: while s_copy:
char = s_copy.pop(0) char = s_copy.pop(0)
if char == delim: if char == delim:
if not s_copy: if not s_copy:
yield len(s)-(len(s_copy)+1) yield len(s) - (len(s_copy) + 1)
elif not s_copy.pop(0) == delim: elif not s_copy.pop(0) == delim:
yield len(s)-(len(s_copy)+2) yield len(s) - (len(s_copy) + 2)
def _sed(sed: str, s: str) -> Optional[str]: def _sed(sed: str, s: str) -> Optional[str]:
if len(sed) > 1: if len(sed) > 1:
delim = sed[1] delim = sed[1]
last = 0 last = 0
parts: List[str] = [] parts: List[str] = []
for i in _delims(sed, delim): for i in _delims(sed, delim):
parts.append(sed[last:i]) parts.append(sed[last:i])
last = i+1 last = i + 1
if len(parts) == 4: if len(parts) == 4:
break break
if last < (len(sed)): if last < (len(sed)):
@ -36,10 +38,10 @@ def _sed(sed: str, s: str) -> Optional[str]:
flags_s = (args or [""])[0] flags_s = (args or [""])[0]
flags = re.I if "i" in flags_s else 0 flags = re.I if "i" in flags_s else 0
count = 0 if "g" in flags_s else 1 count = 0 if "g" in flags_s else 1
for i in reversed(list(_delims(replace, "&"))): for i in reversed(list(_delims(replace, "&"))):
replace = replace[:i] + "\\g<0>" + replace[i+1:] replace = replace[:i] + "\\g<0>" + replace[i + 1 :]
try: try:
compiled = re.compile(pattern, flags) compiled = re.compile(pattern, flags)
@ -49,18 +51,22 @@ def _sed(sed: str, s: str) -> Optional[str]:
else: else:
return None return None
class Database: class Database:
def __init__(self): def __init__(self):
self._settings: Dict[str, str] = {} self._settings: Dict[str, str] = {}
async def get(self, context: str, setting: str) -> Optional[str]: async def get(self, context: str, setting: str) -> Optional[str]:
return self._settings.get(setting, None) return self._settings.get(setting, None)
async def set(self, context: str, setting: str, value: str): async def set(self, context: str, setting: str, value: str):
self._settings[setting] = value self._settings[setting] = value
async def rem(self, context: str, setting: str): async def rem(self, context: str, setting: str):
if setting in self._settings: if setting in self._settings:
del self._settings[setting] del self._settings[setting]
class Server(BaseServer): class Server(BaseServer):
def __init__(self, bot: Bot, name: str, channel: str, database: Database): def __init__(self, bot: Bot, name: str, channel: str, database: Database):
super().__init__(bot, name) super().__init__(bot, name)
@ -78,24 +84,24 @@ class Server(BaseServer):
await self.send(build("JOIN", [self._channel])) await self.send(build("JOIN", [self._channel]))
if ( if (
line.command == "PRIVMSG" and line.command == "PRIVMSG"
self.has_channel(line.params[0]) and and self.has_channel(line.params[0])
not line.hostmask is None and and not line.hostmask is None
not self.casefold(line.hostmask.nickname) == me and and not self.casefold(line.hostmask.nickname) == me
self.has_user(line.hostmask.nickname) and and self.has_user(line.hostmask.nickname)
line.params[1].startswith(TRIGGER)): and line.params[1].startswith(TRIGGER)
):
channel = self.channels[self.casefold(line.params[0])] channel = self.channels[self.casefold(line.params[0])]
user = self.users[self.casefold(line.hostmask.nickname)] user = self.users[self.casefold(line.hostmask.nickname)]
cuser = channel.users[user.nickname_lower] cuser = channel.users[user.nickname_lower]
text = line.params[1].replace(TRIGGER, "", 1) text = line.params[1].replace(TRIGGER, "", 1)
db_context = f"{self.name}:{channel.name}" db_context = f"{self.name}:{channel.name}"
name, _, text = text.partition(" ") name, _, text = text.partition(" ")
action, _, text = text.partition(" ") action, _, text = text.partition(" ")
name = name.lower() name = name.lower()
key = f"factoid-{name}" key = f"factoid-{name}"
out = "" out = ""
if not action or action == "@": if not action or action == "@":
@ -125,10 +131,8 @@ class Server(BaseServer):
elif value: elif value:
changed = _sed(value, current) changed = _sed(value, current)
if not changed is None: if not changed is None:
await self._database.set( await self._database.set(db_context, key, changed)
db_context, key, changed) out = f"{user.nickname}: " f"changed '{name}' factoid"
out = (f"{user.nickname}: "
f"changed '{name}' factoid")
else: else:
out = f"{user.nickname}: invalid sed" out = f"{user.nickname}: invalid sed"
else: else:
@ -136,29 +140,28 @@ class Server(BaseServer):
else: else:
out = f"{user.nickname}: you are not an op" out = f"{user.nickname}: you are not an op"
else: else:
out = f"{user.nickname}: unknown action '{action}'" out = f"{user.nickname}: unknown action '{action}'"
await self.send(build("PRIVMSG", [line.params[0], out])) await self.send(build("PRIVMSG", [line.params[0], out]))
class Bot(BaseBot): class Bot(BaseBot):
def __init__(self, channel: str): def __init__(self, channel: str):
super().__init__() super().__init__()
self._channel = channel self._channel = channel
def create_server(self, name: str): def create_server(self, name: str):
return Server(self, name, self._channel, Database()) return Server(self, name, self._channel, Database())
async def main(hostname: str, channel: str, nickname: str): async def main(hostname: str, channel: str, nickname: str):
bot = Bot(channel) bot = Bot(channel)
params = ConnectionParams( params = ConnectionParams(nickname, hostname, 6697)
nickname,
hostname,
6697
)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)
await bot.run() await bot.run()
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser(description="A simple IRC bot for factoids") parser = ArgumentParser(description="A simple IRC bot for factoids")
parser.add_argument("hostname") parser.add_argument("hostname")

View File

@ -5,28 +5,31 @@ from ircrobots import Bot as BaseBot
from ircrobots import Server as BaseServer from ircrobots import Server as BaseServer
from ircrobots import ConnectionParams, SASLUserPass, SASLSCRAM from ircrobots import ConnectionParams, SASLUserPass, SASLSCRAM
class Server(BaseServer): class Server(BaseServer):
async def line_read(self, line: Line): async def line_read(self, line: Line):
print(f"{self.name} < {line.format()}") print(f"{self.name} < {line.format()}")
async def line_send(self, line: Line): async def line_send(self, line: Line):
print(f"{self.name} > {line.format()}") print(f"{self.name} > {line.format()}")
class Bot(BaseBot): class Bot(BaseBot):
def create_server(self, name: str): def create_server(self, name: str):
return Server(self, name) return Server(self, name)
async def main(): async def main():
bot = Bot() bot = Bot()
sasl_params = SASLUserPass("myusername", "invalidpassword") sasl_params = SASLUserPass("myusername", "invalidpassword")
params = ConnectionParams( params = ConnectionParams(
"MyNickname", "MyNickname", host="chat.freenode.invalid", port=6697, sasl=sasl_params
host = "chat.freenode.invalid", )
port = 6697,
sasl = sasl_params)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)
await bot.run() await bot.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -5,9 +5,8 @@ from ircrobots import Bot as BaseBot
from ircrobots import Server as BaseServer from ircrobots import Server as BaseServer
from ircrobots import ConnectionParams from ircrobots import ConnectionParams
SERVERS = [ SERVERS = [("freenode", "chat.freenode.invalid")]
("freenode", "chat.freenode.invalid")
]
class Server(BaseServer): class Server(BaseServer):
async def line_read(self, line: Line): async def line_read(self, line: Line):
@ -15,13 +14,16 @@ class Server(BaseServer):
if line.command == "001": if line.command == "001":
print(f"connected to {self.isupport.network}") print(f"connected to {self.isupport.network}")
await self.send(build("JOIN", ["#testchannel"])) await self.send(build("JOIN", ["#testchannel"]))
async def line_send(self, line: Line): async def line_send(self, line: Line):
print(f"{self.name} > {line.format()}") print(f"{self.name} > {line.format()}")
class Bot(BaseBot): class Bot(BaseBot):
def create_server(self, name: str): def create_server(self, name: str):
return Server(self, name) return Server(self, name)
async def main(): async def main():
bot = Bot() bot = Bot()
for name, host in SERVERS: for name, host in SERVERS:
@ -30,5 +32,6 @@ async def main():
await bot.run() await bot.run()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -1,5 +1,11 @@
from .bot import Bot from .bot import Bot
from .server import Server from .server import Server
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, from .params import (
STSPolicy, ResumePolicy) ConnectionParams,
from .ircv3 import Capability SASLUserPass,
SASLExternal,
SASLSCRAM,
STSPolicy,
ResumePolicy,
)
from .ircv3 import Capability

View File

@ -1,13 +1,14 @@
from asyncio import Future from asyncio import Future
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional, from typing import Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar
TypeVar)
from irctokens import Line from irctokens import Line
from .matching import IMatchResponse from .matching import IMatchResponse
from .interface import IServer from .interface import IServer
from .ircv3 import TAG_LABEL from .ircv3 import TAG_LABEL
TEvent = TypeVar("TEvent") TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]): class MaybeAwait(Generic[TEvent]):
def __init__(self, func: Callable[[], Awaitable[TEvent]]): def __init__(self, func: Callable[[], Awaitable[TEvent]]):
self._func = func self._func = func
@ -16,13 +17,12 @@ class MaybeAwait(Generic[TEvent]):
coro = self._func() coro = self._func()
return coro.__await__() return coro.__await__()
class WaitFor(object): class WaitFor(object):
def __init__(self, def __init__(self, response: IMatchResponse, deadline: float):
response: IMatchResponse,
deadline: float):
self.response = response self.response = response
self.deadline = deadline self.deadline = deadline
self._label: Optional[str] = None self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future() self._our_fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]: def __await__(self) -> Generator[Any, None, Line]:
@ -32,11 +32,9 @@ class WaitFor(object):
self._label = label self._label = label
def match(self, server: IServer, line: Line): def match(self, server: IServer, line: Line):
if (self._label is not None and if self._label is not None and line.tags is not None:
line.tags is not None):
label = TAG_LABEL.get(line.tags) label = TAG_LABEL.get(line.tags)
if (label is not None and if label is not None and label == self._label:
label == self._label):
return True return True
return self.response.match(server, line) return self.response.match(server, line)

View File

@ -4,10 +4,11 @@ from typing import Dict
from ircstates.server import ServerDisconnectedException from ircstates.server import ServerDisconnectedException
from .server import ConnectionParams, Server from .server import ConnectionParams, Server
from .transport import TCPTransport from .transport import TCPTransport
from .interface import IBot, IServer, ITCPTransport from .interface import IBot, IServer, ITCPTransport
class Bot(IBot): class Bot(IBot):
def __init__(self): def __init__(self):
self.servers: Dict[str, Server] = {} self.servers: Dict[str, Server] = {}
@ -17,9 +18,11 @@ class Bot(IBot):
return Server(self, name) return Server(self, name)
async def disconnected(self, server: IServer): async def disconnected(self, server: IServer):
if (server.name in self.servers and if (
server.params is not None and server.name in self.servers
server.disconnected): and server.params is not None
and server.disconnected
):
reconnect = server.params.reconnect reconnect = server.params.reconnect
@ -30,7 +33,7 @@ class Bot(IBot):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
# let's try again, exponential backoff up to 5 mins # let's try again, exponential backoff up to 5 mins
reconnect = min(reconnect*2, 300) reconnect = min(reconnect * 2, 300)
else: else:
break break
@ -38,10 +41,12 @@ class Bot(IBot):
del self.servers[server.name] del self.servers[server.name]
await server.disconnect() await server.disconnect()
async def add_server(self, async def add_server(
name: str, self,
params: ConnectionParams, name: str,
transport: ITCPTransport = TCPTransport()) -> Server: params: ConnectionParams,
transport: ITCPTransport = TCPTransport(),
) -> Server:
server = self.create_server(name) server = self.create_server(name)
self.servers[name] = server self.servers[name] = server
await server.connect(transport, params) await server.connect(transport, params)

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from .interface import IServer from .interface import IServer
@dataclass @dataclass
class ServerContext(object): class ServerContext(object):
server: IServer server: IServer

View File

@ -1,19 +1,14 @@
from typing import List from typing import List
BOLD = "\x02" BOLD = "\x02"
COLOR = "\x03" COLOR = "\x03"
INVERT = "\x16" INVERT = "\x16"
ITALIC = "\x1D" ITALIC = "\x1D"
UNDERLINE = "\x1F" UNDERLINE = "\x1F"
RESET = "\x0F" RESET = "\x0F"
FORMATTERS = [BOLD, INVERT, ITALIC, UNDERLINE, RESET]
FORMATTERS = [
BOLD,
INVERT,
ITALIC,
UNDERLINE,
RESET
]
def tokens(s: str) -> List[str]: def tokens(s: str) -> List[str]:
tokens: List[str] = [] tokens: List[str] = []
@ -25,9 +20,7 @@ def tokens(s: str) -> List[str]:
for i in range(2): for i in range(2):
if s_copy and s_copy[0].isdigit(): if s_copy and s_copy[0].isdigit():
token += s_copy.pop(0) token += s_copy.pop(0)
if (len(s_copy) > 1 and if len(s_copy) > 1 and s_copy[0] == "," and s_copy[1].isdigit():
s_copy[0] == "," and
s_copy[1].isdigit()):
token += s_copy.pop(0) token += s_copy.pop(0)
token += s_copy.pop(0) token += s_copy.pop(0)
if s_copy and s_copy[0].isdigit(): if s_copy and s_copy[0].isdigit():
@ -38,6 +31,7 @@ def tokens(s: str) -> List[str]:
tokens.append(token) tokens.append(token)
return tokens return tokens
def strip(s: str): def strip(s: str):
for token in tokens(s): for token in tokens(s):
s = s.replace(token, "", 1) s = s.replace(token, "", 1)

View File

@ -1,4 +1,3 @@
def collapse(pattern: str) -> str: def collapse(pattern: str) -> str:
out = "" out = ""
i = 0 i = 0
@ -15,9 +14,10 @@ def collapse(pattern: str) -> str:
if pattern[i:]: if pattern[i:]:
out += pattern[i] out += pattern[i]
i += 1 i += 1
return out return out
def _match(pattern: str, s: str): def _match(pattern: str, s: str):
i, j = 0, 0 i, j = 0, 0
@ -45,10 +45,14 @@ def _match(pattern: str, s: str):
return i == len(pattern) return i == len(pattern)
class Glob(object): class Glob(object):
def __init__(self, pattern: str): def __init__(self, pattern: str):
self._pattern = pattern self._pattern = pattern
def match(self, s: str) -> bool: def match(self, s: str) -> bool:
return _match(self._pattern, s) return _match(self._pattern, s)
def compile(pattern: str) -> Glob: def compile(pattern: str) -> Glob:
return Glob(collapse(pattern)) return Glob(collapse(pattern))

View File

@ -1,16 +1,19 @@
from asyncio import Future from asyncio import Future
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
from enum import IntEnum from enum import IntEnum
from ircstates import Server, Emit from ircstates import Server, Emit
from irctokens import Line, Hostmask from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS from .security import TLS
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
pass pass
class ITCPWriter(object): class ITCPWriter(object):
def write(self, data: bytes): def write(self, data: bytes):
pass pass
@ -20,37 +23,40 @@ class ITCPWriter(object):
async def drain(self): async def drain(self):
pass pass
async def close(self): async def close(self):
pass pass
class ITCPTransport(object): class ITCPTransport(object):
async def connect(self, async def connect(
hostname: str, self,
port: int, hostname: str,
tls: Optional[TLS], port: int,
bindhost: Optional[str]=None tls: Optional[TLS],
) -> Tuple[ITCPReader, ITCPWriter]: bindhost: Optional[str] = None,
) -> Tuple[ITCPReader, ITCPWriter]:
pass pass
class SendPriority(IntEnum): class SendPriority(IntEnum):
HIGH = 0 HIGH = 0
MEDIUM = 10 MEDIUM = 10
LOW = 20 LOW = 20
DEFAULT = MEDIUM DEFAULT = MEDIUM
class SentLine(object): class SentLine(object):
def __init__(self, def __init__(self, id: int, priority: int, line: Line):
id: int, self.id = id
priority: int, self.priority = priority
line: Line): self.line = line
self.id = id
self.priority = priority
self.line = line
self.future: "Future[SentLine]" = Future() self.future: "Future[SentLine]" = Future()
def __lt__(self, other: "SentLine") -> bool: def __lt__(self, other: "SentLine") -> bool:
return self.priority < other.priority return self.priority < other.priority
class ICapability(object): class ICapability(object):
def available(self, capabilities: Iterable[str]) -> Optional[str]: def available(self, capabilities: Iterable[str]) -> Optional[str]:
pass pass
@ -61,38 +67,46 @@ class ICapability(object):
def copy(self) -> "ICapability": def copy(self) -> "ICapability":
pass pass
class IMatchResponse(object): class IMatchResponse(object):
def match(self, server: "IServer", line: Line) -> bool: def match(self, server: "IServer", line: Line) -> bool:
pass pass
class IMatchResponseParam(object): class IMatchResponseParam(object):
def match(self, server: "IServer", arg: str) -> bool: def match(self, server: "IServer", arg: str) -> bool:
pass pass
class IMatchResponseValueParam(IMatchResponseParam): class IMatchResponseValueParam(IMatchResponseParam):
def value(self, server: "IServer"): def value(self, server: "IServer"):
pass pass
def set_value(self, value: str): def set_value(self, value: str):
pass pass
class IMatchResponseHostmask(object): class IMatchResponseHostmask(object):
def match(self, server: "IServer", hostmask: Hostmask) -> bool: def match(self, server: "IServer", hostmask: Hostmask) -> bool:
pass pass
class IServer(Server): class IServer(Server):
bot: "IBot" bot: "IBot"
disconnected: bool disconnected: bool
params: ConnectionParams params: ConnectionParams
desired_caps: Set[ICapability] desired_caps: Set[ICapability]
last_read: float last_read: float
def send_raw(self, line: str, priority=SendPriority.DEFAULT def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
) -> Awaitable[SentLine]:
pass
def send(self, line: Line, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
pass pass
def wait_for(self, def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
response: Union[IMatchResponse, Set[IMatchResponse]] pass
) -> Awaitable[Line]:
def wait_for(
self, response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Awaitable[Line]:
pass pass
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):
@ -101,37 +115,44 @@ class IServer(Server):
def server_address(self) -> Tuple[str, int]: def server_address(self) -> Tuple[str, int]:
pass pass
async def connect(self, async def connect(self, transport: ITCPTransport, params: ConnectionParams):
transport: ITCPTransport,
params: ConnectionParams):
pass pass
async def disconnect(self): async def disconnect(self):
pass pass
def line_preread(self, line: Line): def line_preread(self, line: Line):
pass pass
def line_presend(self, line: Line): def line_presend(self, line: Line):
pass pass
async def line_read(self, line: Line): async def line_read(self, line: Line):
pass pass
async def line_send(self, line: Line): async def line_send(self, line: Line):
pass pass
async def sts_policy(self, sts: STSPolicy): async def sts_policy(self, sts: STSPolicy):
pass pass
async def resume_policy(self, resume: ResumePolicy): async def resume_policy(self, resume: ResumePolicy):
pass pass
def cap_agreed(self, capability: ICapability) -> bool: def cap_agreed(self, capability: ICapability) -> bool:
pass pass
def cap_available(self, capability: ICapability) -> Optional[str]: def cap_available(self, capability: ICapability) -> Optional[str]:
pass pass
async def sasl_auth(self, sasl: SASLParams) -> bool: async def sasl_auth(self, sasl: SASLParams) -> bool:
pass pass
class IBot(object): class IBot(object):
def create_server(self, name: str) -> IServer: def create_server(self, name: str) -> IServer:
pass pass
async def disconnected(self, server: IServer): async def disconnected(self, server: IServer):
pass pass

View File

@ -1,22 +1,25 @@
from time import time from time import time
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from irctokens import build from irctokens import build
from ircstates.server import ServerDisconnectedException from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ANY from .matching import Response, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLS_VERIFYCHAIN from .security import TLS_VERIFYCHAIN
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(
ratified_name: Optional[str], self,
draft_name: Optional[str]=None, ratified_name: Optional[str],
alias: Optional[str]=None, draft_name: Optional[str] = None,
depends_on: List[str]=[]): alias: Optional[str] = None,
self.name = ratified_name depends_on: List[str] = [],
):
self.name = ratified_name
self.draft = draft_name self.draft = draft_name
self.alias = alias or ratified_name self.alias = alias or ratified_name
self.depends_on = depends_on.copy() self.depends_on = depends_on.copy()
@ -26,8 +29,7 @@ class Capability(ICapability):
def match(self, capability: str) -> bool: def match(self, capability: str) -> bool:
return capability in self._caps return capability in self._caps
def available(self, capabilities: Iterable[str] def available(self, capabilities: Iterable[str]) -> Optional[str]:
) -> Optional[str]:
for cap in self._caps: for cap in self._caps:
if not cap is None and cap in capabilities: if not cap is None and cap in capabilities:
return cap return cap
@ -36,16 +38,13 @@ class Capability(ICapability):
def copy(self): def copy(self):
return Capability( return Capability(
self.name, self.name, self.draft, alias=self.alias, depends_on=self.depends_on[:]
self.draft, )
alias=self.alias,
depends_on=self.depends_on[:])
class MessageTag(object): class MessageTag(object):
def __init__(self, def __init__(self, name: Optional[str], draft_name: Optional[str] = None):
name: Optional[str], self.name = name
draft_name: Optional[str]=None):
self.name = name
self.draft = draft_name self.draft = draft_name
self._tags = [self.name, self.draft] self._tags = [self.name, self.draft]
@ -63,37 +62,36 @@ class MessageTag(object):
else: else:
return None return None
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message") CAP_SASL = Capability("sasl")
CAP_STS = Capability("sts", "draft/sts") CAP_ECHO = Capability("echo-message")
CAP_STS = Capability("sts", "draft/sts")
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume") CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
TAG_LABEL = MessageTag("label", "draft/label") TAG_LABEL = MessageTag("label", "draft/label")
LABEL_TAG_MAP = { LABEL_TAG_MAP = {
"draft/labeled-response-0.2": "draft/label", "draft/labeled-response-0.2": "draft/label",
"labeled-response": "label" "labeled-response": "label",
} }
CAPS: List[ICapability] = [ CAPS: List[ICapability] = [
Capability("multi-prefix"), Capability("multi-prefix"),
Capability("chghost"), Capability("chghost"),
Capability("away-notify"), Capability("away-notify"),
Capability("invite-notify"), Capability("invite-notify"),
Capability("account-tag"), Capability("account-tag"),
Capability("account-notify"), Capability("account-notify"),
Capability("extended-join"), Capability("extended-join"),
Capability("message-tags", "draft/message-tags-0.2"), Capability("message-tags", "draft/message-tags-0.2"),
Capability("cap-notify"), Capability("cap-notify"),
Capability("batch"), Capability("batch"),
Capability(None, "draft/rename", alias="rename"), Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname"), Capability("setname", "draft/setname"),
CAP_RESUME CAP_RESUME,
] ]
def _cap_dict(s: str) -> Dict[str, str]: def _cap_dict(s: str) -> Dict[str, str]:
d: Dict[str, str] = {} d: Dict[str, str] = {}
for token in s.split(","): for token in s.split(","):
@ -101,41 +99,44 @@ def _cap_dict(s: str) -> Dict[str, str]:
d[key] = value d[key] = value
return d return d
async def sts_transmute(params: ConnectionParams): async def sts_transmute(params: ConnectionParams):
if not params.sts is None and params.tls is None: if not params.sts is None and params.tls is None:
now = time() now = time()
since = (now-params.sts.created) since = now - params.sts.created
if since <= params.sts.duration: if since <= params.sts.duration:
params.port = params.sts.port params.port = params.sts.port
params.tls = TLS_VERIFYCHAIN params.tls = TLS_VERIFYCHAIN
async def resume_transmute(params: ConnectionParams): async def resume_transmute(params: ConnectionParams):
if params.resume is not None: if params.resume is not None:
params.host = params.resume.address params.host = params.resume.address
class HandshakeCancel(Exception): class HandshakeCancel(Exception):
pass pass
class CAPContext(ServerContext): class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]): async def on_ls(self, tokens: Dict[str, str]):
await self._sts(tokens) await self._sts(tokens)
caps = list(self.server.desired_caps)+CAPS caps = list(self.server.desired_caps) + CAPS
if (not self.server.params.sasl is None and if not self.server.params.sasl is None and not CAP_SASL in caps:
not CAP_SASL in caps):
caps.append(CAP_SASL) caps.append(CAP_SASL)
matched = (c.available(tokens) for c in caps) matched = (c.available(tokens) for c in caps)
cap_names = [name for name in matched if not name is None] cap_names = [name for name in matched if not name is None]
if cap_names: if cap_names:
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)])) await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
while cap_names: while cap_names:
line = await self.server.wait_for({ line = await self.server.wait_for(
Response("CAP", [ANY, "ACK"]), {Response("CAP", [ANY, "ACK"]), Response("CAP", [ANY, "NAK"])}
Response("CAP", [ANY, "NAK"]) )
})
current_caps = line.params[2].split(" ") current_caps = line.params[2].split(" ")
for cap in current_caps: for cap in current_caps:
@ -144,8 +145,7 @@ class CAPContext(ServerContext):
if CAP_RESUME.available(current_caps): if CAP_RESUME.available(current_caps):
await self.resume_token() await self.resume_token()
if (self.server.cap_agreed(CAP_SASL) and if self.server.cap_agreed(CAP_SASL) and not self.server.params.sasl is None:
not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl) await self.server.sasl_auth(self.server.params.sasl)
async def resume_token(self): async def resume_token(self):
@ -160,10 +160,9 @@ class CAPContext(ServerContext):
if previous_policy is not None and not self.server.registered: if previous_policy is not None and not self.server.registered:
await self.server.send(build("RESUME", [previous_policy.token])) await self.server.send(build("RESUME", [previous_policy.token]))
line = await self.server.wait_for({ line = await self.server.wait_for(
Response("RESUME", ["SUCCESS"]), {Response("RESUME", ["SUCCESS"]), Response("FAIL", ["RESUME"])}
Response("FAIL", ["RESUME"]) )
})
if line.command == "RESUME": if line.command == "RESUME":
raise HandshakeCancel() raise HandshakeCancel()
@ -179,11 +178,11 @@ class CAPContext(ServerContext):
cap_sts = CAP_STS.available(tokens) cap_sts = CAP_STS.available(tokens)
if not cap_sts is None: if not cap_sts is None:
sts_dict = _cap_dict(tokens[cap_sts]) sts_dict = _cap_dict(tokens[cap_sts])
params = self.server.params params = self.server.params
if not params.tls: if not params.tls:
if "port" in sts_dict: if "port" in sts_dict:
params.port = int(sts_dict["port"]) params.port = int(sts_dict["port"])
params.tls = TLS_VERIFYCHAIN params.tls = TLS_VERIFYCHAIN
await self.server.bot.disconnect(self.server) await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params) await self.server.bot.add_server(self.server.name, params)
@ -194,6 +193,6 @@ class CAPContext(ServerContext):
int(time()), int(time()),
params.port, params.port,
int(sts_dict["duration"]), int(sts_dict["duration"]),
"preload" in sts_dict) "preload" in sts_dict,
)
await self.server.sts_policy(policy) await self.server.sts_policy(policy)

View File

@ -1,3 +1,2 @@
from .responses import * from .responses import *
from .params import * from .params import *

View File

@ -1,16 +1,24 @@
from re import compile as re_compile from re import compile as re_compile
from typing import Optional, Pattern, Union from typing import Optional, Pattern, Union
from irctokens import Hostmask from irctokens import Hostmask
from ..interface import (IMatchResponseParam, IMatchResponseValueParam, from ..interface import (
IMatchResponseHostmask, IServer) IMatchResponseParam,
from ..glob import Glob, compile as glob_compile IMatchResponseValueParam,
IMatchResponseHostmask,
IServer,
)
from ..glob import Glob, compile as glob_compile
from .. import formatting from .. import formatting
class Any(IMatchResponseParam): class Any(IMatchResponseParam):
def __repr__(self) -> str: def __repr__(self) -> str:
return "Any()" return "Any()"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return True return True
ANY = Any() ANY = Any()
# NOT # NOT
@ -18,107 +26,142 @@ ANY = Any()
# REGEX # REGEX
# LITERAL # LITERAL
class Literal(IMatchResponseValueParam): class Literal(IMatchResponseValueParam):
def __init__(self, value: str): def __init__(self, value: str):
self._value = value self._value = value
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self._value!r}" return f"{self._value!r}"
def value(self, server: IServer) -> str: def value(self, server: IServer) -> str:
return self._value return self._value
def set_value(self, value: str): def set_value(self, value: str):
self._value = value self._value = value
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return arg == self._value return arg == self._value
TYPE_MAYBELIT = Union[str, IMatchResponseParam]
TYPE_MAYBELIT = Union[str, IMatchResponseParam]
TYPE_MAYBELIT_VALUE = Union[str, IMatchResponseValueParam] TYPE_MAYBELIT_VALUE = Union[str, IMatchResponseValueParam]
def _assure_lit(value: TYPE_MAYBELIT_VALUE) -> IMatchResponseValueParam: def _assure_lit(value: TYPE_MAYBELIT_VALUE) -> IMatchResponseValueParam:
if isinstance(value, str): if isinstance(value, str):
return Literal(value) return Literal(value)
else: else:
return value return value
class Not(IMatchResponseParam): class Not(IMatchResponseParam):
def __init__(self, param: IMatchResponseParam): def __init__(self, param: IMatchResponseParam):
self._param = param self._param = param
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Not({self._param!r})" return f"Not({self._param!r})"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(server, arg) return not self._param.match(server, arg)
class ParamValuePassthrough(IMatchResponseValueParam): class ParamValuePassthrough(IMatchResponseValueParam):
_value: IMatchResponseValueParam _value: IMatchResponseValueParam
def value(self, server: IServer): def value(self, server: IServer):
return self._value.value(server) return self._value.value(server)
def set_value(self, value: str): def set_value(self, value: str):
self._value.set_value(value) self._value.set_value(value)
class Folded(ParamValuePassthrough): class Folded(ParamValuePassthrough):
def __init__(self, value: TYPE_MAYBELIT_VALUE): def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value) self._value = _assure_lit(value)
self._folded = False self._folded = False
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Folded({self._value!r})" return f"Folded({self._value!r})"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
if not self._folded: if not self._folded:
value = self.value(server) value = self.value(server)
folded = server.casefold(value) folded = server.casefold(value)
self.set_value(folded) self.set_value(folded)
self._folded = True self._folded = True
return self._value.match(server, server.casefold(arg)) return self._value.match(server, server.casefold(arg))
class Formatless(IMatchResponseParam): class Formatless(IMatchResponseParam):
def __init__(self, value: TYPE_MAYBELIT_VALUE): def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value) self._value = _assure_lit(value)
def __repr__(self) -> str: def __repr__(self) -> str:
brepr = super().__repr__() brepr = super().__repr__()
return f"Formatless({brepr})" return f"Formatless({brepr})"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
strip = formatting.strip(arg) strip = formatting.strip(arg)
return self._value.match(server, strip) return self._value.match(server, strip)
class Regex(IMatchResponseParam): class Regex(IMatchResponseParam):
def __init__(self, value: str): def __init__(self, value: str):
self._value = value self._value = value
self._pattern: Optional[Pattern] = None self._pattern: Optional[Pattern] = None
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
if self._pattern is None: if self._pattern is None:
self._pattern = re_compile(self._value) self._pattern = re_compile(self._value)
return bool(self._pattern.search(arg)) return bool(self._pattern.search(arg))
class Self(IMatchResponseParam): class Self(IMatchResponseParam):
def __repr__(self) -> str: def __repr__(self) -> str:
return "Self()" return "Self()"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return server.casefold(arg) == server.nickname_lower return server.casefold(arg) == server.nickname_lower
SELF = Self() SELF = Self()
class MaskSelf(IMatchResponseHostmask): class MaskSelf(IMatchResponseHostmask):
def __repr__(self) -> str: def __repr__(self) -> str:
return "MaskSelf()" return "MaskSelf()"
def match(self, server: IServer, hostmask: Hostmask): def match(self, server: IServer, hostmask: Hostmask):
return server.casefold(hostmask.nickname) == server.nickname_lower return server.casefold(hostmask.nickname) == server.nickname_lower
MASK_SELF = MaskSelf() MASK_SELF = MaskSelf()
class Nick(IMatchResponseHostmask): class Nick(IMatchResponseHostmask):
def __init__(self, nickname: str): def __init__(self, nickname: str):
self._nickname = nickname self._nickname = nickname
self._folded: Optional[str] = None self._folded: Optional[str] = None
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Nick({self._nickname!r})" return f"Nick({self._nickname!r})"
def match(self, server: IServer, hostmask: Hostmask): def match(self, server: IServer, hostmask: Hostmask):
if self._folded is None: if self._folded is None:
self._folded = server.casefold(self._nickname) self._folded = server.casefold(self._nickname)
return self._folded == server.casefold(hostmask.nickname) return self._folded == server.casefold(hostmask.nickname)
class Mask(IMatchResponseHostmask): class Mask(IMatchResponseHostmask):
def __init__(self, mask: str): def __init__(self, mask: str):
self._mask = mask self._mask = mask
self._compiled: Optional[Glob] self._compiled: Optional[Glob]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Mask({self._mask!r})" return f"Mask({self._mask!r})"
def match(self, server: IServer, hostmask: Hostmask): def match(self, server: IServer, hostmask: Hostmask):
if self._compiled is None: if self._compiled is None:
self._compiled = glob_compile(self._mask) self._compiled = glob_compile(self._mask)

View File

@ -1,17 +1,25 @@
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
from irctokens import Line from irctokens import Line
from ..interface import (IServer, IMatchResponse, IMatchResponseParam, from ..interface import (
IMatchResponseHostmask) IServer,
from .params import * IMatchResponse,
IMatchResponseParam,
IMatchResponseHostmask,
)
from .params import *
TYPE_PARAM = Union[str, IMatchResponseParam] TYPE_PARAM = Union[str, IMatchResponseParam]
class Responses(IMatchResponse): class Responses(IMatchResponse):
def __init__(self, def __init__(
commands: Sequence[str], self,
params: Sequence[TYPE_PARAM]=[], commands: Sequence[str],
source: Optional[IMatchResponseHostmask]=None): params: Sequence[TYPE_PARAM] = [],
source: Optional[IMatchResponseHostmask] = None,
):
self._commands = commands self._commands = commands
self._source = source self._source = source
self._params: Sequence[IMatchResponseParam] = [] self._params: Sequence[IMatchResponseParam] = []
for param in params: for param in params:
@ -25,36 +33,43 @@ class Responses(IMatchResponse):
def match(self, server: IServer, line: Line) -> bool: def match(self, server: IServer, line: Line) -> bool:
for command in self._commands: for command in self._commands:
if (line.command == command and ( if line.command == command and (
self._source is None or ( self._source is None
line.hostmask is not None and or (
self._source.match(server, line.hostmask) line.hostmask is not None
))): and self._source.match(server, line.hostmask)
)
):
for i, param in enumerate(self._params): for i, param in enumerate(self._params):
if (i >= len(line.params) or if i >= len(line.params) or not param.match(server, line.params[i]):
not param.match(server, line.params[i])):
break break
else: else:
return True return True
else: else:
return False return False
class Response(Responses): class Response(Responses):
def __init__(self, def __init__(
command: str, self,
params: Sequence[TYPE_PARAM]=[], command: str,
source: Optional[IMatchResponseHostmask]=None): params: Sequence[TYPE_PARAM] = [],
source: Optional[IMatchResponseHostmask] = None,
):
super().__init__([command], params, source=source) super().__init__([command], params, source=source)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Response({self._commands[0]}: {self._params!r})" return f"Response({self._commands[0]}: {self._params!r})"
class ResponseOr(IMatchResponse): class ResponseOr(IMatchResponse):
def __init__(self, *responses: IMatchResponse): def __init__(self, *responses: IMatchResponse):
self._responses = responses self._responses = responses
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ResponseOr({self._responses!r})" return f"ResponseOr({self._responses!r})"
def match(self, server: IServer, line: Line) -> bool: def match(self, server: IServer, line: Line) -> bool:
for response in self._responses: for response in self._responses:
if response.match(server, line): if response.match(server, line):

View File

@ -1,74 +1,80 @@
from re import compile as re_compile from re import compile as re_compile
from typing import List, Optional from typing import List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN
class SASLParams(object): class SASLParams(object):
mechanism: str mechanism: str
@dataclass @dataclass
class _SASLUserPass(SASLParams): class _SASLUserPass(SASLParams):
username: str username: str
password: str password: str
class SASLUserPass(_SASLUserPass): class SASLUserPass(_SASLUserPass):
mechanism = "USERPASS" mechanism = "USERPASS"
class SASLSCRAM(_SASLUserPass): class SASLSCRAM(_SASLUserPass):
mechanism = "SCRAM" mechanism = "SCRAM"
class SASLExternal(SASLParams): class SASLExternal(SASLParams):
mechanism = "EXTERNAL" mechanism = "EXTERNAL"
@dataclass @dataclass
class STSPolicy(object): class STSPolicy(object):
created: int created: int
port: int port: int
duration: int duration: int
preload: bool preload: bool
@dataclass @dataclass
class ResumePolicy(object): class ResumePolicy(object):
address: str address: str
token: str token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]") RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = { _TLS_TYPES = {"+": TLS_VERIFYCHAIN, "~": TLS_NOVERIFY}
"+": TLS_VERIFYCHAIN,
"~": TLS_NOVERIFY
}
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
host: str host: str
port: int port: int
tls: Optional[TLS] = TLS_VERIFYCHAIN tls: Optional[TLS] = TLS_VERIFYCHAIN
username: Optional[str] = None username: Optional[str] = None
realname: Optional[str] = None realname: Optional[str] = None
bindhost: Optional[str] = None bindhost: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None sts: Optional[STSPolicy] = None
resume: Optional[ResumePolicy] = None resume: Optional[ResumePolicy] = None
reconnect: int = 10 # seconds reconnect: int = 10 # seconds
alt_nicknames: List[str] = field(default_factory=list) alt_nicknames: List[str] = field(default_factory=list)
autojoin: List[str] = field(default_factory=list) autojoin: List[str] = field(default_factory=list)
@staticmethod @staticmethod
def from_hoststring( def from_hoststring(nickname: str, hoststring: str) -> "ConnectionParams":
nickname: str,
hoststring: str
) -> "ConnectionParams":
ipv6host = RE_IPV6HOST.search(hoststring) ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0: if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1) host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:] port_s = hoststring[ipv6host.end() + 1 :]
else: else:
host, _, port_s = hoststring.strip().partition(":") host, _, port_s = hoststring.strip().partition(":")

View File

@ -1,52 +1,62 @@
from typing import List from typing import List
from enum import Enum from enum import Enum
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from irctokens import build from irctokens import build
from ircstates.numerics import * from ircstates.numerics import *
from .matching import Responses, Response, ANY from .matching import Responses, Response, ANY
from .contexts import ServerContext from .contexts import ServerContext
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext, SCRAMAlgorithm from .scram import SCRAMContext, SCRAMAlgorithm
SASL_SCRAM_MECHANISMS = [ SASL_SCRAM_MECHANISMS = [
"SCRAM-SHA-512", "SCRAM-SHA-512",
"SCRAM-SHA-256", "SCRAM-SHA-256",
"SCRAM-SHA-1", "SCRAM-SHA-1",
] ]
SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"] SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS + ["PLAIN"]
class SASLResult(Enum): class SASLResult(Enum):
NONE = 0 NONE = 0
SUCCESS = 1 SUCCESS = 1
FAILURE = 2 FAILURE = 2
ALREADY = 3 ALREADY = 3
class SASLError(Exception): class SASLError(Exception):
pass pass
class SASLUnknownMechanismError(SASLError): class SASLUnknownMechanismError(SASLError):
pass pass
AUTH_BYTE_MAX = 400 AUTH_BYTE_MAX = 400
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY]) AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
NUMERICS_FAIL = Response(ERR_SASLFAIL) NUMERICS_FAIL = Response(ERR_SASLFAIL)
NUMERICS_INITIAL = Responses([ NUMERICS_INITIAL = Responses(
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED [ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED]
]) )
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
def _b64e(s: str): def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii") return b64encode(s.encode("utf8")).decode("ascii")
def _b64eb(s: bytes) -> str: def _b64eb(s: bytes) -> str:
# encode-from-bytes # encode-from-bytes
return b64encode(s).decode("ascii") return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes: def _b64db(s: str) -> bytes:
# decode-to-bytes # decode-to-bytes
return b64decode(s) return b64decode(s)
class SASLContext(ServerContext): class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult: async def from_params(self, params: SASLParams) -> SASLResult:
if isinstance(params, SASLUserPass): if isinstance(params, SASLUserPass):
@ -57,15 +67,12 @@ class SASLContext(ServerContext):
return await self.external() return await self.external()
else: else:
raise SASLUnknownMechanismError( raise SASLUnknownMechanismError(
"SASLParams given with unknown mechanism " "SASLParams given with unknown mechanism " f"{params.mechanism!r}"
f"{params.mechanism!r}") )
async def external(self) -> SASLResult: async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for({ line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL})
AUTHENTICATE_ANY,
NUMERICS_INITIAL
})
if line.command == "907": if line.command == "907":
# we've done SASL already. cleanly abort # we've done SASL already. cleanly abort
@ -73,8 +80,8 @@ class SASLContext(ServerContext):
elif line.command == "908": elif line.command == "908":
available = line.params[1].split(",") available = line.params[1].split(",")
raise SASLUnknownMechanismError( raise SASLUnknownMechanismError(
"Server does not support SASL EXTERNAL " "Server does not support SASL EXTERNAL " f"(it supports {available}"
f"(it supports {available}") )
elif line.command == "AUTHENTICATE" and line.params[0] == "+": elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.server.send(build("AUTHENTICATE", ["+"])) await self.server.send(build("AUTHENTICATE", ["+"]))
@ -89,11 +96,12 @@ class SASLContext(ServerContext):
async def scram(self, username: str, password: str) -> SASLResult: async def scram(self, username: str, password: str) -> SASLResult:
return await self.userpass(username, password, SASL_SCRAM_MECHANISMS) return await self.userpass(username, password, SASL_SCRAM_MECHANISMS)
async def userpass(self, async def userpass(
username: str, self,
password: str, username: str,
mechanisms: List[str]=SASL_USERPASS_MECHANISMS password: str,
) -> SASLResult: mechanisms: List[str] = SASL_USERPASS_MECHANISMS,
) -> SASLResult:
def _common(server_mechs) -> List[str]: def _common(server_mechs) -> List[str]:
mechs: List[str] = [] mechs: List[str] = []
for our_mech in mechanisms: for our_mech in mechanisms:
@ -106,23 +114,21 @@ class SASLContext(ServerContext):
raise SASLUnknownMechanismError( raise SASLUnknownMechanismError(
"No matching SASL mechanims. " "No matching SASL mechanims. "
f"(we want: {mechanisms} " f"(we want: {mechanisms} "
f"server has: {server_mechs})") f"server has: {server_mechs})"
)
if self.server.available_caps["sasl"]: if self.server.available_caps["sasl"]:
# CAP v3.2 tells us what mechs it supports # CAP v3.2 tells us what mechs it supports
available = self.server.available_caps["sasl"].split(",") available = self.server.available_caps["sasl"].split(",")
match = _common(available) match = _common(available)
else: else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of # CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported # what mechanisms are supported
match = mechanisms match = mechanisms
while match: while match:
await self.server.send(build("AUTHENTICATE", [match[0]])) await self.server.send(build("AUTHENTICATE", [match[0]]))
line = await self.server.wait_for({ line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL})
AUTHENTICATE_ANY,
NUMERICS_INITIAL
})
if line.command == "907": if line.command == "907":
# we've done SASL already. cleanly abort # we've done SASL already. cleanly abort
@ -130,7 +136,7 @@ class SASLContext(ServerContext):
elif line.command == "908": elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported # prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",") available = line.params[1].split(",")
match = _common(available) match = _common(available)
await self.server.wait_for(NUMERICS_FAIL) await self.server.wait_for(NUMERICS_FAIL)
elif line.command == "AUTHENTICATE" and line.params[0] == "+": elif line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text = "" auth_text = ""
@ -138,8 +144,7 @@ class SASLContext(ServerContext):
if match[0] == "PLAIN": if match[0] == "PLAIN":
auth_text = f"{username}\0{username}\0{password}" auth_text = f"{username}\0{username}\0{password}"
elif match[0].startswith("SCRAM-SHA-"): elif match[0].startswith("SCRAM-SHA-"):
auth_text = await self._scram( auth_text = await self._scram(match[0], username, password)
match[0], username, password)
if not auth_text == "+": if not auth_text == "+":
auth_text = _b64e(auth_text) auth_text = _b64e(auth_text)
@ -148,7 +153,7 @@ class SASLContext(ServerContext):
await self._send_auth_text(auth_text) await self._send_auth_text(auth_text)
line = await self.server.wait_for(NUMERICS_LAST) line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903": if line.command == "903":
return SASLResult.SUCCESS return SASLResult.SUCCESS
elif line.command == "904": elif line.command == "904":
match.pop(0) match.pop(0)
@ -157,11 +162,8 @@ class SASLContext(ServerContext):
return SASLResult.FAILURE return SASLResult.FAILURE
async def _scram(self, algo_str: str, async def _scram(self, algo_str: str, username: str, password: str) -> str:
username: str, algo_str_prep = algo_str.replace("SCRAM-", "", 1).replace("-", "").upper()
password: str) -> str:
algo_str_prep = algo_str.replace("SCRAM-", "", 1
).replace("-", "").upper()
try: try:
algo = SCRAMAlgorithm(algo_str_prep) algo = SCRAMAlgorithm(algo_str_prep)
except ValueError: except ValueError:
@ -179,15 +181,15 @@ class SASLContext(ServerContext):
line = await self.server.wait_for(AUTHENTICATE_ANY) line = await self.server.wait_for(AUTHENTICATE_ANY)
server_final = _b64db(line.params[0]) server_final = _b64db(line.params[0])
verified = scram.server_final(server_final) verified = scram.server_final(server_final)
#TODO PANIC if verified is false! # TODO PANIC if verified is false!
return "+" return "+"
else: else:
return "" return ""
async def _send_auth_text(self, text: str): async def _send_auth_text(self, text: str):
n = AUTH_BYTE_MAX n = AUTH_BYTE_MAX
chunks = [text[i:i+n] for i in range(0, len(text), n)] chunks = [text[i : i + n] for i in range(0, len(text), n)]
if len(chunks[-1]) == 400: if len(chunks[-1]) == 400:
chunks.append("+") chunks.append("+")

View File

@ -7,51 +7,60 @@ from typing import Dict
# https://www.iana.org/assignments/hash-function-text-names/ # https://www.iana.org/assignments/hash-function-text-names/
# MD2 has been removed as it's unacceptably weak # MD2 has been removed as it's unacceptably weak
class SCRAMAlgorithm(Enum): class SCRAMAlgorithm(Enum):
MD5 = "MD5" MD5 = "MD5"
SHA_1 = "SHA1" SHA_1 = "SHA1"
SHA_224 = "SHA224" SHA_224 = "SHA224"
SHA_256 = "SHA256" SHA_256 = "SHA256"
SHA_384 = "SHA384" SHA_384 = "SHA384"
SHA_512 = "SHA512" SHA_512 = "SHA512"
SCRAM_ERRORS = [ SCRAM_ERRORS = [
"invalid-encoding", "invalid-encoding",
"extensions-not-supported", # unrecognized 'm' value "extensions-not-supported", # unrecognized 'm' value
"invalid-proof", "invalid-proof",
"channel-bindings-dont-match", "channel-bindings-dont-match",
"server-does-support-channel-binding", "server-does-support-channel-binding",
"channel-binding-not-supported", "channel-binding-not-supported",
"unsupported-channel-binding-type", "unsupported-channel-binding-type",
"unknown-user", "unknown-user",
"invalid-username-encoding", # invalid utf8 or bad SASLprep "invalid-username-encoding", # invalid utf8 or bad SASLprep
"no-resources" "no-resources",
] ]
def _scram_nonce() -> bytes: def _scram_nonce() -> bytes:
return base64.b64encode(os.urandom(32)) return base64.b64encode(os.urandom(32))
def _scram_escape(s: bytes) -> bytes: def _scram_escape(s: bytes) -> bytes:
return s.replace(b"=", b"=3D").replace(b",", b"=2C") return s.replace(b"=", b"=3D").replace(b",", b"=2C")
def _scram_unescape(s: bytes) -> bytes: def _scram_unescape(s: bytes) -> bytes:
return s.replace(b"=3D", b"=").replace(b"=2C", b",") return s.replace(b"=3D", b"=").replace(b"=2C", b",")
def _scram_xor(s1: bytes, s2: bytes) -> bytes: def _scram_xor(s1: bytes, s2: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(s1, s2)) return bytes(a ^ b for a, b in zip(s1, s2))
class SCRAMState(Enum): class SCRAMState(Enum):
NONE = 0 NONE = 0
CLIENT_FIRST = 1 CLIENT_FIRST = 1
CLIENT_FINAL = 2 CLIENT_FINAL = 2
SUCCESS = 3 SUCCESS = 3
FAILURE = 4 FAILURE = 4
VERIFY_FAILURE = 5 VERIFY_FAILURE = 5
class SCRAMError(Exception): class SCRAMError(Exception):
pass pass
class SCRAMContext(object): class SCRAMContext(object):
def __init__(self, algo: SCRAMAlgorithm, def __init__(self, algo: SCRAMAlgorithm, username: str, password: str):
username: str, self._algo = algo
password: str):
self._algo = algo
self._username = username.encode("utf8") self._username = username.encode("utf8")
self._password = password.encode("utf8") self._password = password.encode("utf8")
@ -59,11 +68,11 @@ class SCRAMContext(object):
self.error = "" self.error = ""
self.raw_error = "" self.raw_error = ""
self._client_first = b"" self._client_first = b""
self._client_nonce = b"" self._client_nonce = b""
self._salted_password = b"" self._salted_password = b""
self._auth_message = b"" self._auth_message = b""
def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]: def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]:
pieces = (piece.split(b"=", 1) for piece in data.split(b",")) pieces = (piece.split(b"=", 1) for piece in data.split(b","))
@ -71,6 +80,7 @@ class SCRAMContext(object):
def _hmac(self, key: bytes, msg: bytes) -> bytes: def _hmac(self, key: bytes, msg: bytes) -> bytes:
return hmac.new(key, msg, self._algo.value).digest() return hmac.new(key, msg, self._algo.value).digest()
def _hash(self, msg: bytes) -> bytes: def _hash(self, msg: bytes) -> bytes:
return hashlib.new(self._algo.value, msg).digest() return hashlib.new(self._algo.value, msg).digest()
@ -89,7 +99,9 @@ class SCRAMContext(object):
self.state = SCRAMState.CLIENT_FIRST self.state = SCRAMState.CLIENT_FIRST
self._client_nonce = _scram_nonce() self._client_nonce = _scram_nonce()
self._client_first = b"n=%s,r=%s" % ( self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), self._client_nonce) _scram_escape(self._username),
self._client_nonce,
)
# n,,n=<username>,r=<nonce> # n,,n=<username>,r=<nonce>
return b"n,,%s" % self._client_first return b"n,,%s" % self._client_first
@ -109,17 +121,17 @@ class SCRAMContext(object):
if self._assert_error(pieces): if self._assert_error(pieces):
return b"" return b""
nonce = pieces[b"r"] # server combines your nonce with it's own nonce = pieces[b"r"] # server combines your nonce with it's own
if (not nonce.startswith(self._client_nonce) or if not nonce.startswith(self._client_nonce) or nonce == self._client_nonce:
nonce == self._client_nonce):
self._fail("nonce-unacceptable") self._fail("nonce-unacceptable")
return b"" return b""
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
iterations = int(pieces[b"i"]) iterations = int(pieces[b"i"])
salted_password = hashlib.pbkdf2_hmac(self._algo.value, salted_password = hashlib.pbkdf2_hmac(
self._password, salt, iterations, dklen=None) self._algo.value, self._password, salt, iterations, dklen=None
)
self._salted_password = salted_password self._salted_password = salted_password
client_key = self._hmac(salted_password, b"Client Key") client_key = self._hmac(salted_password, b"Client Key")

View File

@ -1,26 +1,35 @@
import ssl import ssl
class TLS: class TLS:
pass pass
# tls without verification # tls without verification
class TLSNoVerify(TLS): class TLSNoVerify(TLS):
pass pass
TLS_NOVERIFY = TLSNoVerify() TLS_NOVERIFY = TLSNoVerify()
# verify via CAs # verify via CAs
class TLSVerifyChain(TLS): class TLSVerifyChain(TLS):
pass pass
TLS_VERIFYCHAIN = TLSVerifyChain() TLS_VERIFYCHAIN = TLSVerifyChain()
# verify by a pinned hash # verify by a pinned hash
class TLSVerifyHash(TLSNoVerify): class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str): def __init__(self, sum: str):
self.sum = sum.lower() self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash): class TLSVerifySHA512(TLSVerifyHash):
pass pass
def tls_context(verify: bool=True) -> ssl.SSLContext:
def tls_context(verify: bool = True) -> ssl.SSLContext:
ctx = ssl.create_default_context() ctx = ssl.create_default_context()
if not verify: if not verify:
ctx.check_hostname = False ctx.check_hostname = False

View File

@ -1,36 +1,58 @@
import asyncio import asyncio
from asyncio import Future, PriorityQueue from asyncio import Future, PriorityQueue
from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List, from typing import (
Optional, Set, Tuple, Union) AsyncIterable,
Awaitable,
Deque,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from collections import deque from collections import deque
from time import monotonic from time import monotonic
import anyio import anyio
from asyncio_rlock import RLock from asyncio_rlock import RLock
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_ from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser from ircstates import Emit, Channel, ChannelUser
from ircstates.numerics import * from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException from ircstates.server import ServerDisconnectedException
from ircstates.names import Name from ircstates.names import Name
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, from .ircv3 import (
CAP_LABEL, LABEL_TAG_MAP, resume_transmute) CAPContext,
from .sasl import SASLContext, SASLResult sts_transmute,
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, CAP_ECHO,
Folded) CAP_SASL,
from .asyncs import MaybeAwait, WaitFor CAP_LABEL,
from .struct import Whois LABEL_TAG_MAP,
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy resume_transmute,
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, )
IMatchResponse) from .sasl import SASLContext, SASLResult
from .matching import ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, Folded
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (
IBot,
ICapability,
IServer,
SentLine,
SendPriority,
IMatchResponse,
)
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds THROTTLE_TIME = 2 # seconds
PING_TIMEOUT = 60 # seconds PING_TIMEOUT = 60 # seconds
WAIT_TIMEOUT = 20 # seconds WAIT_TIMEOUT = 20 # seconds
JOIN_ERR_FIRST = [ JOIN_ERR_FIRST = [
ERR_NOSUCHCHANNEL, ERR_NOSUCHCHANNEL,
@ -41,13 +63,14 @@ JOIN_ERR_FIRST = [
ERR_INVITEONLYCHAN, ERR_INVITEONLYCHAN,
ERR_BADCHANNELKEY, ERR_BADCHANNELKEY,
ERR_NEEDREGGEDNICK, ERR_NEEDREGGEDNICK,
ERR_THROTTLE ERR_THROTTLE,
] ]
class Server(IServer): class Server(IServer):
_reader: ITCPReader _reader: ITCPReader
_writer: ITCPWriter _writer: ITCPWriter
params: ConnectionParams params: ConnectionParams
def __init__(self, bot: IBot, name: str): def __init__(self, bot: IBot, name: str):
super().__init__(name) super().__init__(name)
@ -58,23 +81,23 @@ class Server(IServer):
self.throttle = Throttler(rate_limit=100, period=1) self.throttle = Throttler(rate_limit=100, period=1)
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self.last_read = monotonic() self.last_read = monotonic()
self._sent_count: int = 0 self._sent_count: int = 0
self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Line] = deque() self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._ping_sent = False self._ping_sent = False
self._read_lguard = RLock() self._read_lguard = RLock()
self.read_lock = self._read_lguard self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock() self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event() self._wait_for = asyncio.Event()
self._pending_who: Deque[str] = deque() self._pending_who: Deque[str] = deque()
self._alt_nicks: List[str] = [] self._alt_nicks: List[str] = []
def hostmask(self) -> str: def hostmask(self) -> str:
hostmask = self.nickname hostmask = self.nickname
@ -84,13 +107,10 @@ class Server(IServer):
hostmask += f"@{self.hostname}" hostmask += f"@{self.hostname}"
return hostmask return hostmask
def send_raw(self, line: str, priority=SendPriority.DEFAULT def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
) -> Awaitable[SentLine]:
return self.send(tokenise(line), priority) return self.send(tokenise(line), priority)
def send(self,
line: Line, def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
self.line_presend(line) self.line_presend(line)
sent_line = SentLine(self._sent_count, priority, line) sent_line = SentLine(self._sent_count, priority, line)
@ -110,28 +130,25 @@ class Server(IServer):
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):
self.throttle.rate_limit = rate self.throttle.rate_limit = rate
self.throttle.period = time self.throttle.period = time
def server_address(self) -> Tuple[str, int]: def server_address(self) -> Tuple[str, int]:
return self._writer.get_peer() return self._writer.get_peer()
async def connect(self, async def connect(self, transport: ITCPTransport, params: ConnectionParams):
transport: ITCPTransport,
params: ConnectionParams):
await sts_transmute(params) await sts_transmute(params)
await resume_transmute(params) await resume_transmute(params)
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host, params.port, tls=params.tls, bindhost=params.bindhost
params.port, )
tls =params.tls,
bindhost =params.bindhost)
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
self.params = params self.params = params
await self.handshake() await self.handshake()
async def disconnect(self): async def disconnect(self):
if not self._writer is None: if not self._writer is None:
await self._writer.close() await self._writer.close()
@ -145,29 +162,35 @@ class Server(IServer):
alt_nicks = self.params.alt_nicknames alt_nicks = self.params.alt_nicknames
if not alt_nicks: if not alt_nicks:
alt_nicks = [nickname+"_"*i for i in range(1, 4)] alt_nicks = [nickname + "_" * i for i in range(1, 4)]
self._alt_nicks = alt_nicks self._alt_nicks = alt_nicks
# these must remain non-awaited; reading hasn't started yet # these must remain non-awaited; reading hasn't started yet
if not self.params.password is None: if not self.params.password is None:
self.send(build("PASS", [self.params.password])) self.send(build("PASS", [self.params.password]))
self.send(build("CAP", ["LS", "302"])) self.send(build("CAP", ["LS", "302"]))
self.send(build("NICK", [nickname])) self.send(build("NICK", [nickname]))
self.send(build("USER", [username, "0", "*", realname])) self.send(build("USER", [username, "0", "*", realname]))
# to be overridden # to be overridden
def line_preread(self, line: Line): def line_preread(self, line: Line):
pass pass
def line_presend(self, line: Line): def line_presend(self, line: Line):
pass pass
async def line_read(self, line: Line): async def line_read(self, line: Line):
pass pass
async def line_send(self, line: Line): async def line_send(self, line: Line):
pass pass
async def sts_policy(self, sts: STSPolicy): async def sts_policy(self, sts: STSPolicy):
pass pass
async def resume_policy(self, resume: ResumePolicy): async def resume_policy(self, resume: ResumePolicy):
pass pass
# /to be overriden # /to be overriden
async def _on_read(self, line: Line, emit: Optional[Emit]): async def _on_read(self, line: Line, emit: Optional[Emit]):
@ -176,13 +199,14 @@ class Server(IServer):
elif line.command == RPL_ENDOFWHO: elif line.command == RPL_ENDOFWHO:
chan = self.casefold(line.params[1]) chan = self.casefold(line.params[1])
if (self._pending_who and if self._pending_who and self._pending_who[0] == chan:
self._pending_who[0] == chan):
self._pending_who.popleft() self._pending_who.popleft()
await self._next_who() await self._next_who()
elif (line.command in { elif (
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE line.command
} and not self.registered): in {ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE}
and not self.registered
):
if self._alt_nicks: if self._alt_nicks:
nick = self._alt_nicks.pop(0) nick = self._alt_nicks.pop(0)
await self.send(build("NICK", [nick])) await self.send(build("NICK", [nick]))
@ -203,8 +227,7 @@ class Server(IServer):
await self._check_regain([line.params[1]]) await self._check_regain([line.params[1]])
elif line.command == RPL_MONOFFLINE: elif line.command == RPL_MONOFFLINE:
await self._check_regain(line.params[1].split(",")) await self._check_regain(line.params[1].split(","))
elif (line.command in ["NICK", "QUIT"] and elif line.command in ["NICK", "QUIT"] and line.source is not None:
line.source is not None):
await self._check_regain([line.hostmask.nickname]) await self._check_regain([line.hostmask.nickname])
elif emit is not None: elif emit is not None:
@ -216,10 +239,9 @@ class Server(IServer):
await self._batch_joins(self.params.autojoin) await self._batch_joins(self.params.autojoin)
elif emit.command == "CAP": elif emit.command == "CAP":
if emit.subcommand == "NEW": if emit.subcommand == "NEW":
await self._cap_ls(emit) await self._cap_ls(emit)
elif (emit.subcommand == "LS" and elif emit.subcommand == "LS" and emit.finished:
emit.finished):
if not self.registered: if not self.registered:
await CAPContext(self).handshake() await CAPContext(self).handshake()
else: else:
@ -227,7 +249,7 @@ class Server(IServer):
elif emit.command == "JOIN": elif emit.command == "JOIN":
if emit.self and not emit.channel is None: if emit.self and not emit.channel is None:
chan = emit.channel.name_lower chan = emit.channel.name_lower
await self.send(build("MODE", [chan])) await self.send(build("MODE", [chan]))
modes = "".join(self.isupport.chanmodes.a_modes) modes = "".join(self.isupport.chanmodes.a_modes)
@ -241,18 +263,18 @@ class Server(IServer):
async def _check_regain(self, nicks: List[str]): async def _check_regain(self, nicks: List[str]):
for nick in nicks: for nick in nicks:
if (self.casefold_equals(nick, self.params.nickname) and if (
not self.nickname == self.params.nickname): self.casefold_equals(nick, self.params.nickname)
and not self.nickname == self.params.nickname
):
await self.send(build("NICK", [self.params.nickname])) await self.send(build("NICK", [self.params.nickname]))
async def _batch_joins(self, async def _batch_joins(self, channels: List[str], batch_n: int = 10):
channels: List[str], # TODO: do as many JOINs in one line as we can fit
batch_n: int=10): # TODO: channel keys
#TODO: do as many JOINs in one line as we can fit
#TODO: channel keys
for i in range(0, len(channels), batch_n): for i in range(0, len(channels), batch_n):
batch = channels[i:i+batch_n] batch = channels[i : i + batch_n]
await self.send(build("JOIN", [",".join(batch)])) await self.send(build("JOIN", [",".join(batch)]))
async def _next_who(self): async def _next_who(self):
@ -275,7 +297,7 @@ class Server(IServer):
return None return None
self.last_read = monotonic() self.last_read = monotonic()
lines = self.recv(data) lines = self.recv(data)
for line in lines: for line in lines:
self.line_preread(line) self.line_preread(line)
self._read_queue.append(line) self._read_queue.append(line)
@ -287,10 +309,10 @@ class Server(IServer):
if not self._process_queue: if not self._process_queue:
async with self._read_lwork: async with self._read_lwork:
read_aw = self._read_line(PING_TIMEOUT) read_aw = self._read_line(PING_TIMEOUT)
dones, notdones = await asyncio.wait( dones, notdones = await asyncio.wait(
[read_aw, self._wait_for.wait()], [read_aw, self._wait_for.wait()],
return_when=asyncio.FIRST_COMPLETED return_when=asyncio.FIRST_COMPLETED,
) )
self._wait_for.clear() self._wait_for.clear()
@ -314,11 +336,12 @@ class Server(IServer):
line, emit = self._process_queue.popleft() line, emit = self._process_queue.popleft()
await self._on_read(line, emit) await self._on_read(line, emit)
async def wait_for(self, async def wait_for(
response: Union[IMatchResponse, Set[IMatchResponse]], self,
sent_aw: Optional[Awaitable[SentLine]]=None, response: Union[IMatchResponse, Set[IMatchResponse]],
timeout: float=WAIT_TIMEOUT sent_aw: Optional[Awaitable[SentLine]] = None,
) -> Line: timeout: float = WAIT_TIMEOUT,
) -> Line:
response_obj: IMatchResponse response_obj: IMatchResponse
if isinstance(response, set): if isinstance(response, set):
@ -340,8 +363,9 @@ class Server(IServer):
return line return line
async def _on_send_line(self, line: Line): async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and if line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and not self.cap_agreed(
not self.cap_agreed(CAP_ECHO)): CAP_ECHO
):
new_line = line.with_source(self.hostmask()) new_line = line.with_source(self.hostmask())
self._read_queue.append(new_line) self._read_queue.append(new_line)
@ -349,15 +373,13 @@ class Server(IServer):
while True: while True:
lines: List[SentLine] = [] lines: List[SentLine] = []
while (not lines or while not lines or (len(lines) < 5 and self._send_queue.qsize() > 0):
(len(lines) < 5 and self._send_queue.qsize() > 0)):
prio_line = await self._send_queue.get() prio_line = await self._send_queue.get()
lines.append(prio_line) lines.append(prio_line)
for line in lines: for line in lines:
async with self.throttle: async with self.throttle:
self._writer.write( self._writer.write(f"{line.line.format()}\r\n".encode("utf8"))
f"{line.line.format()}\r\n".encode("utf8"))
await self._writer.drain() await self._writer.drain()
@ -369,6 +391,7 @@ class Server(IServer):
# CAP-related # CAP-related
def cap_agreed(self, capability: ICapability) -> bool: def cap_agreed(self, capability: ICapability) -> bool:
return bool(self.cap_available(capability)) return bool(self.cap_available(capability))
def cap_available(self, capability: ICapability) -> Optional[str]: def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps) return capability.available(self.agreed_caps)
@ -381,78 +404,81 @@ class Server(IServer):
await CAPContext(self).on_ls(tokens) await CAPContext(self).on_ls(tokens)
async def sasl_auth(self, params: SASLParams) -> bool: async def sasl_auth(self, params: SASLParams) -> bool:
if (self.sasl_state == SASLResult.NONE and if self.sasl_state == SASLResult.NONE and self.cap_agreed(CAP_SASL):
self.cap_agreed(CAP_SASL)):
res = await SASLContext(self).from_params(params) res = await SASLContext(self).from_params(params)
self.sasl_state = res self.sasl_state = res
return True return True
else: else:
return False return False
# /CAP-related # /CAP-related
def send_nick(self, new_nick: str) -> Awaitable[bool]: def send_nick(self, new_nick: str) -> Awaitable[bool]:
fut = self.send(build("NICK", [new_nick])) fut = self.send(build("NICK", [new_nick]))
async def _assure() -> bool: async def _assure() -> bool:
line = await self.wait_for({ line = await self.wait_for(
Response("NICK", [Folded(new_nick)], source=MASK_SELF), {
Responses([ Response("NICK", [Folded(new_nick)], source=MASK_SELF),
ERR_BANNICKCHANGE, Responses(
ERR_NICKTOOFAST, [ERR_BANNICKCHANGE, ERR_NICKTOOFAST, ERR_CANTCHANGENICK], [ANY]
ERR_CANTCHANGENICK ),
], [ANY]), Responses(
Responses([ [ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE],
ERR_NICKNAMEINUSE, [ANY, Folded(new_nick)],
ERR_ERRONEUSNICKNAME, ),
ERR_UNAVAILRESOURCE },
], [ANY, Folded(new_nick)]) fut,
}, fut) )
return line.command == "NICK" return line.command == "NICK"
return MaybeAwait(_assure) return MaybeAwait(_assure)
def send_join(self, def send_join(self, name: str, key: Optional[str] = None) -> Awaitable[Channel]:
name: str,
key: Optional[str]=None
) -> Awaitable[Channel]:
fut = self.send_joins([name], [] if key is None else [key]) fut = self.send_joins([name], [] if key is None else [key])
async def _assure(): async def _assure():
channels = await fut channels = await fut
return channels[0] return channels[0]
return MaybeAwait(_assure) return MaybeAwait(_assure)
def send_part(self, name: str): def send_part(self, name: str):
fut = self.send(build("PART", [name])) fut = self.send(build("PART", [name]))
async def _assure(): async def _assure():
line = await self.wait_for( line = await self.wait_for(
Response("PART", [Folded(name)], source=MASK_SELF), Response("PART", [Folded(name)], source=MASK_SELF), fut
fut
) )
return return
return MaybeAwait(_assure) return MaybeAwait(_assure)
def send_joins(self, def send_joins(
names: List[str], self, names: List[str], keys: List[str] = []
keys: List[str]=[] ) -> Awaitable[List[Channel]]:
) -> Awaitable[List[Channel]]:
folded_names = [self.casefold(name) for name in names] folded_names = [self.casefold(name) for name in names]
if not keys: if not keys:
fut = self.send(build("JOIN", [",".join(names)])) fut = self.send(build("JOIN", [",".join(names)]))
else: else:
fut = self.send(build("JOIN", [",".join(names)]+keys)) fut = self.send(build("JOIN", [",".join(names)] + keys))
async def _assure(): async def _assure():
channels: List[Channel] = [] channels: List[Channel] = []
while folded_names: while folded_names:
line = await self.wait_for({ line = await self.wait_for(
Response(RPL_CHANNELMODEIS, [ANY, ANY]), {
Responses(JOIN_ERR_FIRST, [ANY, ANY]), Response(RPL_CHANNELMODEIS, [ANY, ANY]),
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]), Responses(JOIN_ERR_FIRST, [ANY, ANY]),
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]) Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
}, fut) Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]),
},
fut,
)
chan: Optional[str] = None chan: Optional[str] = None
if line.command == RPL_CHANNELMODEIS: if line.command == RPL_CHANNELMODEIS:
@ -462,7 +488,7 @@ class Server(IServer):
elif line.command == ERR_USERONCHANNEL: elif line.command == ERR_USERONCHANNEL:
chan = line.params[2] chan = line.params[2]
elif line.command == ERR_LINKCHANNEL: elif line.command == ERR_LINKCHANNEL:
#XXX i dont like this # XXX i dont like this
chan = line.params[2] chan = line.params[2]
await self.wait_for( await self.wait_for(
Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)]) Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)])
@ -477,51 +503,58 @@ class Server(IServer):
channels.append(self.channels[folded]) channels.append(self.channels[folded])
return channels return channels
return MaybeAwait(_assure) return MaybeAwait(_assure)
def send_message(self, target: str, message: str def send_message(self, target: str, message: str) -> Awaitable[Optional[str]]:
) -> Awaitable[Optional[str]]:
fut = self.send(build("PRIVMSG", [target, message])) fut = self.send(build("PRIVMSG", [target, message]))
async def _assure(): async def _assure():
line = await self.wait_for( line = await self.wait_for(
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), fut
fut
) )
if line.command == "PRIVMSG": if line.command == "PRIVMSG":
return line.params[1] return line.params[1]
else: else:
return None return None
return MaybeAwait(_assure) return MaybeAwait(_assure)
def send_whois(self, def send_whois(
target: str, self, target: str, remote: bool = False
remote: bool=False ) -> Awaitable[Optional[Whois]]:
) -> Awaitable[Optional[Whois]]:
args = [target] args = [target]
if remote: if remote:
args.append(target) args.append(target)
fut = self.send(build("WHOIS", args)) fut = self.send(build("WHOIS", args))
async def _assure() -> Optional[Whois]: async def _assure() -> Optional[Whois]:
folded = self.casefold(target) folded = self.casefold(target)
params = [ANY, Folded(folded)] params = [ANY, Folded(folded)]
obj = Whois() obj = Whois()
while True: while True:
line = await self.wait_for(Responses([ line = await self.wait_for(
ERR_NOSUCHNICK, Responses(
ERR_NOSUCHSERVER, [
RPL_WHOISUSER, ERR_NOSUCHNICK,
RPL_WHOISSERVER, ERR_NOSUCHSERVER,
RPL_WHOISOPERATOR, RPL_WHOISUSER,
RPL_WHOISIDLE, RPL_WHOISSERVER,
RPL_WHOISCHANNELS, RPL_WHOISOPERATOR,
RPL_WHOISHOST, RPL_WHOISIDLE,
RPL_WHOISACCOUNT, RPL_WHOISCHANNELS,
RPL_WHOISSECURE, RPL_WHOISHOST,
RPL_ENDOFWHOIS RPL_WHOISACCOUNT,
], params), fut) RPL_WHOISSECURE,
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]: RPL_ENDOFWHOIS,
],
params,
),
fut,
)
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
return None return None
elif line.command == RPL_WHOISUSER: elif line.command == RPL_WHOISUSER:
nick, user, host, _, real = line.params[1:] nick, user, host, _, real = line.params[1:]
@ -531,7 +564,7 @@ class Server(IServer):
obj.realname = real obj.realname = real
elif line.command == RPL_WHOISIDLE: elif line.command == RPL_WHOISIDLE:
idle, signon, _ = line.params[2:] idle, signon, _ = line.params[2:]
obj.idle = int(idle) obj.idle = int(idle)
obj.signon = int(signon) obj.signon = int(signon)
elif line.command == RPL_WHOISACCOUNT: elif line.command == RPL_WHOISACCOUNT:
obj.account = line.params[2] obj.account = line.params[2]
@ -544,11 +577,11 @@ class Server(IServer):
symbols = "" symbols = ""
while channel[0] in self.isupport.prefix.prefixes: while channel[0] in self.isupport.prefix.prefixes:
symbols += channel[0] symbols += channel[0]
channel = channel[1:] channel = channel[1:]
channel_user = ChannelUser( channel_user = ChannelUser(
Name(obj.nickname, folded), Name(obj.nickname, folded),
Name(channel, self.casefold(channel)) Name(channel, self.casefold(channel)),
) )
for symbol in symbols: for symbol in symbols:
mode = self.isupport.prefix.from_prefix(symbol) mode = self.isupport.prefix.from_prefix(symbol)
@ -558,4 +591,5 @@ class Server(IServer):
obj.channels.append(channel_user) obj.channels.append(channel_user)
elif line.command == RPL_ENDOFWHOIS: elif line.command == RPL_ENDOFWHOIS:
return obj return obj
return MaybeAwait(_assure) return MaybeAwait(_assure)

View File

@ -3,21 +3,21 @@ from dataclasses import dataclass
from ircstates import ChannelUser from ircstates import ChannelUser
class Whois(object): class Whois(object):
server: Optional[str] = None server: Optional[str] = None
server_info: Optional[str] = None server_info: Optional[str] = None
operator: bool = False operator: bool = False
secure: bool = False secure: bool = False
signon: Optional[int] = None signon: Optional[int] = None
idle: Optional[int] = None idle: Optional[int] = None
channels: Optional[List[ChannelUser]] = None channels: Optional[List[ChannelUser]] = None
nickname: str = "" nickname: str = ""
username: str = "" username: str = ""
hostname: str = "" hostname: str = ""
realname: str = "" realname: str = ""
account: Optional[str] = None account: Optional[str] = None

View File

@ -1,12 +1,12 @@
from hashlib import sha512 from hashlib import sha512
from ssl import SSLContext from ssl import SSLContext
from typing import Optional, Tuple from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash, from .security import tls_context, TLS, TLSNoVerify, TLSVerifyHash, TLSVerifySHA512
TLSVerifySHA512)
class TCPReader(ITCPReader): class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader): def __init__(self, reader: StreamReader):
@ -14,6 +14,8 @@ class TCPReader(ITCPReader):
async def read(self, byte_count: int) -> bytes: async def read(self, byte_count: int) -> bytes:
return await self._reader.read(byte_count) return await self._reader.read(byte_count)
class TCPWriter(ITCPWriter): class TCPWriter(ITCPWriter):
def __init__(self, writer: StreamWriter): def __init__(self, writer: StreamWriter):
self._writer = writer self._writer = writer
@ -32,13 +34,15 @@ class TCPWriter(ITCPWriter):
self._writer.close() self._writer.close()
await self._writer.wait_closed() await self._writer.wait_closed()
class TCPTransport(ITCPTransport): class TCPTransport(ITCPTransport):
async def connect(self, async def connect(
hostname: str, self,
port: int, hostname: str,
tls: Optional[TLS], port: int,
bindhost: Optional[str]=None tls: Optional[TLS],
) -> Tuple[ITCPReader, ITCPWriter]: bindhost: Optional[str] = None,
) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None cur_ssl: Optional[SSLContext] = None
if tls is not None: if tls is not None:
@ -54,22 +58,20 @@ class TCPTransport(ITCPTransport):
hostname, hostname,
port, port,
server_hostname=server_hostname, server_hostname=server_hostname,
ssl =cur_ssl, ssl=cur_ssl,
local_addr =local_addr) local_addr=local_addr,
)
if isinstance(tls, TLSVerifyHash): if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info( cert: bytes = writer.transport.get_extra_info("ssl_object").getpeercert(
"ssl_object" True
).getpeercert(True) )
if isinstance(tls, TLSVerifySHA512): if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest() sum = sha512(cert).hexdigest()
else: else:
raise ValueError(f"unknown hash pinning {type(tls)}") raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum: if not sum == tls.sum:
raise ValueError( raise ValueError(f"pinned hash for {hostname} does not match ({sum})")
f"pinned hash for {hostname} does not match ({sum})"
)
return (TCPReader(reader), TCPWriter(writer)) return (TCPReader(reader), TCPWriter(writer))

View File

@ -24,8 +24,8 @@ setup(
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Operating System :: POSIX", "Operating System :: POSIX",
"Operating System :: Microsoft :: Windows", "Operating System :: Microsoft :: Windows",
"Topic :: Communications :: Chat :: Internet Relay Chat" "Topic :: Communications :: Chat :: Internet Relay Chat",
], ],
python_requires='>=3.7', python_requires=">=3.7",
install_requires=install_requires install_requires=install_requires,
) )

View File

@ -1,6 +1,7 @@
import unittest import unittest
from ircrobots import glob from ircrobots import glob
class GlobTestCollapse(unittest.TestCase): class GlobTestCollapse(unittest.TestCase):
def test(self): def test(self):
c1 = glob.collapse("**?*") c1 = glob.collapse("**?*")