formatted with black

This commit is contained in:
A_D 2022-01-30 20:05:31 +02:00
parent 0ce3b9b0b0
commit 9b31aff951
No known key found for this signature in database
GPG Key ID: 4BE9EB7DF45076C4
23 changed files with 629 additions and 469 deletions

View File

@ -1,32 +1,34 @@
import asyncio, re
from argparse import ArgumentParser
from typing import Dict, List, Optional
from argparse import ArgumentParser
from typing import Dict, List, Optional
from irctokens import build, Line
from ircrobots import Bot as BaseBot
from ircrobots import Bot as BaseBot
from ircrobots import Server as BaseServer
from ircrobots import ConnectionParams
TRIGGER = "!"
def _delims(s: str, delim: str):
s_copy = list(s)
while s_copy:
char = s_copy.pop(0)
if char == delim:
if not s_copy:
yield len(s)-(len(s_copy)+1)
yield len(s) - (len(s_copy) + 1)
elif not s_copy.pop(0) == delim:
yield len(s)-(len(s_copy)+2)
yield len(s) - (len(s_copy) + 2)
def _sed(sed: str, s: str) -> Optional[str]:
if len(sed) > 1:
delim = sed[1]
last = 0
delim = sed[1]
last = 0
parts: List[str] = []
for i in _delims(sed, delim):
parts.append(sed[last:i])
last = i+1
last = i + 1
if len(parts) == 4:
break
if last < (len(sed)):
@ -36,10 +38,10 @@ def _sed(sed: str, s: str) -> Optional[str]:
flags_s = (args or [""])[0]
flags = re.I if "i" in flags_s else 0
count = 0 if "g" in flags_s else 1
count = 0 if "g" in flags_s else 1
for i in reversed(list(_delims(replace, "&"))):
replace = replace[:i] + "\\g<0>" + replace[i+1:]
replace = replace[:i] + "\\g<0>" + replace[i + 1 :]
try:
compiled = re.compile(pattern, flags)
@ -49,18 +51,22 @@ def _sed(sed: str, s: str) -> Optional[str]:
else:
return None
class Database:
def __init__(self):
self._settings: Dict[str, str] = {}
async def get(self, context: str, setting: str) -> Optional[str]:
return self._settings.get(setting, None)
async def set(self, context: str, setting: str, value: str):
self._settings[setting] = value
async def rem(self, context: str, setting: str):
if setting in self._settings:
del self._settings[setting]
class Server(BaseServer):
def __init__(self, bot: Bot, name: str, channel: str, database: Database):
super().__init__(bot, name)
@ -78,24 +84,24 @@ class Server(BaseServer):
await self.send(build("JOIN", [self._channel]))
if (
line.command == "PRIVMSG" and
self.has_channel(line.params[0]) and
not line.hostmask is None and
not self.casefold(line.hostmask.nickname) == me and
self.has_user(line.hostmask.nickname) and
line.params[1].startswith(TRIGGER)):
line.command == "PRIVMSG"
and self.has_channel(line.params[0])
and not line.hostmask is None
and not self.casefold(line.hostmask.nickname) == me
and self.has_user(line.hostmask.nickname)
and line.params[1].startswith(TRIGGER)
):
channel = self.channels[self.casefold(line.params[0])]
user = self.users[self.casefold(line.hostmask.nickname)]
cuser = channel.users[user.nickname_lower]
text = line.params[1].replace(TRIGGER, "", 1)
user = self.users[self.casefold(line.hostmask.nickname)]
cuser = channel.users[user.nickname_lower]
text = line.params[1].replace(TRIGGER, "", 1)
db_context = f"{self.name}:{channel.name}"
name, _, text = text.partition(" ")
name, _, text = text.partition(" ")
action, _, text = text.partition(" ")
name = name.lower()
key = f"factoid-{name}"
key = f"factoid-{name}"
out = ""
if not action or action == "@":
@ -125,10 +131,8 @@ class Server(BaseServer):
elif value:
changed = _sed(value, current)
if not changed is None:
await self._database.set(
db_context, key, changed)
out = (f"{user.nickname}: "
f"changed '{name}' factoid")
await self._database.set(db_context, key, changed)
out = f"{user.nickname}: " f"changed '{name}' factoid"
else:
out = f"{user.nickname}: invalid sed"
else:
@ -136,29 +140,28 @@ class Server(BaseServer):
else:
out = f"{user.nickname}: you are not an op"
else:
out = f"{user.nickname}: unknown action '{action}'"
await self.send(build("PRIVMSG", [line.params[0], out]))
class Bot(BaseBot):
def __init__(self, channel: str):
super().__init__()
self._channel = channel
def create_server(self, name: str):
return Server(self, name, self._channel, Database())
async def main(hostname: str, channel: str, nickname: str):
bot = Bot(channel)
params = ConnectionParams(
nickname,
hostname,
6697
)
params = ConnectionParams(nickname, hostname, 6697)
await bot.add_server("freenode", params)
await bot.run()
if __name__ == "__main__":
parser = ArgumentParser(description="A simple IRC bot for factoids")
parser.add_argument("hostname")

View File

@ -5,28 +5,31 @@ from ircrobots import Bot as BaseBot
from ircrobots import Server as BaseServer
from ircrobots import ConnectionParams, SASLUserPass, SASLSCRAM
class Server(BaseServer):
async def line_read(self, line: Line):
print(f"{self.name} < {line.format()}")
async def line_send(self, line: Line):
print(f"{self.name} > {line.format()}")
class Bot(BaseBot):
def create_server(self, name: str):
return Server(self, name)
async def main():
bot = Bot()
sasl_params = SASLUserPass("myusername", "invalidpassword")
params = ConnectionParams(
"MyNickname",
host = "chat.freenode.invalid",
port = 6697,
sasl = sasl_params)
params = ConnectionParams(
"MyNickname", host="chat.freenode.invalid", port=6697, sasl=sasl_params
)
await bot.add_server("freenode", params)
await bot.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -5,9 +5,8 @@ from ircrobots import Bot as BaseBot
from ircrobots import Server as BaseServer
from ircrobots import ConnectionParams
SERVERS = [
("freenode", "chat.freenode.invalid")
]
SERVERS = [("freenode", "chat.freenode.invalid")]
class Server(BaseServer):
async def line_read(self, line: Line):
@ -15,13 +14,16 @@ class Server(BaseServer):
if line.command == "001":
print(f"connected to {self.isupport.network}")
await self.send(build("JOIN", ["#testchannel"]))
async def line_send(self, line: Line):
print(f"{self.name} > {line.format()}")
class Bot(BaseBot):
def create_server(self, name: str):
return Server(self, name)
async def main():
bot = Bot()
for name, host in SERVERS:
@ -30,5 +32,6 @@ async def main():
await bot.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,5 +1,11 @@
from .bot import Bot
from .bot import Bot
from .server import Server
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy, ResumePolicy)
from .ircv3 import Capability
from .params import (
ConnectionParams,
SASLUserPass,
SASLExternal,
SASLSCRAM,
STSPolicy,
ResumePolicy,
)
from .ircv3 import Capability

View File

@ -1,13 +1,14 @@
from asyncio import Future
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
TypeVar)
from asyncio import Future
from typing import Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar
from irctokens import Line
from .matching import IMatchResponse
from irctokens import Line
from .matching import IMatchResponse
from .interface import IServer
from .ircv3 import TAG_LABEL
from .ircv3 import TAG_LABEL
TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]):
def __init__(self, func: Callable[[], Awaitable[TEvent]]):
self._func = func
@ -16,13 +17,12 @@ class MaybeAwait(Generic[TEvent]):
coro = self._func()
return coro.__await__()
class WaitFor(object):
def __init__(self,
response: IMatchResponse,
deadline: float):
def __init__(self, response: IMatchResponse, deadline: float):
self.response = response
self.deadline = deadline
self._label: Optional[str] = None
self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]:
@ -32,11 +32,9 @@ class WaitFor(object):
self._label = label
def match(self, server: IServer, line: Line):
if (self._label is not None and
line.tags is not None):
if self._label is not None and line.tags is not None:
label = TAG_LABEL.get(line.tags)
if (label is not None and
label == self._label):
if label is not None and label == self._label:
return True
return self.response.match(server, line)

View File

@ -4,10 +4,11 @@ from typing import Dict
from ircstates.server import ServerDisconnectedException
from .server import ConnectionParams, Server
from .server import ConnectionParams, Server
from .transport import TCPTransport
from .interface import IBot, IServer, ITCPTransport
class Bot(IBot):
def __init__(self):
self.servers: Dict[str, Server] = {}
@ -17,9 +18,11 @@ class Bot(IBot):
return Server(self, name)
async def disconnected(self, server: IServer):
if (server.name in self.servers and
server.params is not None and
server.disconnected):
if (
server.name in self.servers
and server.params is not None
and server.disconnected
):
reconnect = server.params.reconnect
@ -30,7 +33,7 @@ class Bot(IBot):
except Exception as e:
traceback.print_exc()
# let's try again, exponential backoff up to 5 mins
reconnect = min(reconnect*2, 300)
reconnect = min(reconnect * 2, 300)
else:
break
@ -38,10 +41,12 @@ class Bot(IBot):
del self.servers[server.name]
await server.disconnect()
async def add_server(self,
name: str,
params: ConnectionParams,
transport: ITCPTransport = TCPTransport()) -> Server:
async def add_server(
self,
name: str,
params: ConnectionParams,
transport: ITCPTransport = TCPTransport(),
) -> Server:
server = self.create_server(name)
self.servers[name] = server
await server.connect(transport, params)

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass
from .interface import IServer
@dataclass
class ServerContext(object):
server: IServer

View File

@ -1,19 +1,14 @@
from typing import List
BOLD = "\x02"
COLOR = "\x03"
INVERT = "\x16"
ITALIC = "\x1D"
BOLD = "\x02"
COLOR = "\x03"
INVERT = "\x16"
ITALIC = "\x1D"
UNDERLINE = "\x1F"
RESET = "\x0F"
RESET = "\x0F"
FORMATTERS = [BOLD, INVERT, ITALIC, UNDERLINE, RESET]
FORMATTERS = [
BOLD,
INVERT,
ITALIC,
UNDERLINE,
RESET
]
def tokens(s: str) -> List[str]:
tokens: List[str] = []
@ -25,9 +20,7 @@ def tokens(s: str) -> List[str]:
for i in range(2):
if s_copy and s_copy[0].isdigit():
token += s_copy.pop(0)
if (len(s_copy) > 1 and
s_copy[0] == "," and
s_copy[1].isdigit()):
if len(s_copy) > 1 and s_copy[0] == "," and s_copy[1].isdigit():
token += s_copy.pop(0)
token += s_copy.pop(0)
if s_copy and s_copy[0].isdigit():
@ -38,6 +31,7 @@ def tokens(s: str) -> List[str]:
tokens.append(token)
return tokens
def strip(s: str):
for token in tokens(s):
s = s.replace(token, "", 1)

View File

@ -1,4 +1,3 @@
def collapse(pattern: str) -> str:
out = ""
i = 0
@ -15,9 +14,10 @@ def collapse(pattern: str) -> str:
if pattern[i:]:
out += pattern[i]
i += 1
i += 1
return out
def _match(pattern: str, s: str):
i, j = 0, 0
@ -45,10 +45,14 @@ def _match(pattern: str, s: str):
return i == len(pattern)
class Glob(object):
def __init__(self, pattern: str):
self._pattern = pattern
def match(self, s: str) -> bool:
return _match(self._pattern, s)
def compile(pattern: str) -> Glob:
return Glob(collapse(pattern))

View File

@ -1,16 +1,19 @@
from asyncio import Future
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
from enum import IntEnum
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
from enum import IntEnum
from ircstates import Server, Emit
from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS
class ITCPReader(object):
async def read(self, byte_count: int):
pass
class ITCPWriter(object):
def write(self, data: bytes):
pass
@ -20,37 +23,40 @@ class ITCPWriter(object):
async def drain(self):
pass
async def close(self):
pass
class ITCPTransport(object):
async def connect(self,
hostname: str,
port: int,
tls: Optional[TLS],
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]:
async def connect(
self,
hostname: str,
port: int,
tls: Optional[TLS],
bindhost: Optional[str] = None,
) -> Tuple[ITCPReader, ITCPWriter]:
pass
class SendPriority(IntEnum):
HIGH = 0
HIGH = 0
MEDIUM = 10
LOW = 20
LOW = 20
DEFAULT = MEDIUM
class SentLine(object):
def __init__(self,
id: int,
priority: int,
line: Line):
self.id = id
self.priority = priority
self.line = line
def __init__(self, id: int, priority: int, line: Line):
self.id = id
self.priority = priority
self.line = line
self.future: "Future[SentLine]" = Future()
def __lt__(self, other: "SentLine") -> bool:
return self.priority < other.priority
class ICapability(object):
def available(self, capabilities: Iterable[str]) -> Optional[str]:
pass
@ -61,38 +67,46 @@ class ICapability(object):
def copy(self) -> "ICapability":
pass
class IMatchResponse(object):
def match(self, server: "IServer", line: Line) -> bool:
pass
class IMatchResponseParam(object):
def match(self, server: "IServer", arg: str) -> bool:
pass
class IMatchResponseValueParam(IMatchResponseParam):
def value(self, server: "IServer"):
pass
def set_value(self, value: str):
pass
class IMatchResponseHostmask(object):
def match(self, server: "IServer", hostmask: Hostmask) -> bool:
pass
class IServer(Server):
bot: "IBot"
bot: "IBot"
disconnected: bool
params: ConnectionParams
params: ConnectionParams
desired_caps: Set[ICapability]
last_read: float
last_read: float
def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
pass
def send(self, line: Line, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
pass
def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Awaitable[Line]:
def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
pass
def wait_for(
self, response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Awaitable[Line]:
pass
def set_throttle(self, rate: int, time: float):
@ -101,37 +115,44 @@ class IServer(Server):
def server_address(self) -> Tuple[str, int]:
pass
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
async def connect(self, transport: ITCPTransport, params: ConnectionParams):
pass
async def disconnect(self):
pass
def line_preread(self, line: Line):
pass
def line_presend(self, line: Line):
pass
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
pass
async def sts_policy(self, sts: STSPolicy):
pass
async def resume_policy(self, resume: ResumePolicy):
pass
def cap_agreed(self, capability: ICapability) -> bool:
pass
def cap_available(self, capability: ICapability) -> Optional[str]:
pass
async def sasl_auth(self, sasl: SASLParams) -> bool:
pass
class IBot(object):
def create_server(self, name: str) -> IServer:
pass
async def disconnected(self, server: IServer):
pass

View File

@ -1,22 +1,25 @@
from time import time
from typing import Dict, Iterable, List, Optional, Tuple
from time import time
from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass
from irctokens import build
from irctokens import build
from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext
from .matching import Response, ANY
from .contexts import ServerContext
from .matching import Response, ANY
from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLS_VERIFYCHAIN
from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLS_VERIFYCHAIN
class Capability(ICapability):
def __init__(self,
ratified_name: Optional[str],
draft_name: Optional[str]=None,
alias: Optional[str]=None,
depends_on: List[str]=[]):
self.name = ratified_name
def __init__(
self,
ratified_name: Optional[str],
draft_name: Optional[str] = None,
alias: Optional[str] = None,
depends_on: List[str] = [],
):
self.name = ratified_name
self.draft = draft_name
self.alias = alias or ratified_name
self.depends_on = depends_on.copy()
@ -26,8 +29,7 @@ class Capability(ICapability):
def match(self, capability: str) -> bool:
return capability in self._caps
def available(self, capabilities: Iterable[str]
) -> Optional[str]:
def available(self, capabilities: Iterable[str]) -> Optional[str]:
for cap in self._caps:
if not cap is None and cap in capabilities:
return cap
@ -36,16 +38,13 @@ class Capability(ICapability):
def copy(self):
return Capability(
self.name,
self.draft,
alias=self.alias,
depends_on=self.depends_on[:])
self.name, self.draft, alias=self.alias, depends_on=self.depends_on[:]
)
class MessageTag(object):
def __init__(self,
name: Optional[str],
draft_name: Optional[str]=None):
self.name = name
def __init__(self, name: Optional[str], draft_name: Optional[str] = None):
self.name = name
self.draft = draft_name
self._tags = [self.name, self.draft]
@ -63,37 +62,36 @@ class MessageTag(object):
else:
return None
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message")
CAP_STS = Capability("sts", "draft/sts")
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message")
CAP_STS = Capability("sts", "draft/sts")
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
TAG_LABEL = MessageTag("label", "draft/label")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
TAG_LABEL = MessageTag("label", "draft/label")
LABEL_TAG_MAP = {
"draft/labeled-response-0.2": "draft/label",
"labeled-response": "label"
"labeled-response": "label",
}
CAPS: List[ICapability] = [
Capability("multi-prefix"),
Capability("chghost"),
Capability("away-notify"),
Capability("invite-notify"),
Capability("account-tag"),
Capability("account-notify"),
Capability("extended-join"),
Capability("message-tags", "draft/message-tags-0.2"),
Capability("cap-notify"),
Capability("batch"),
Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname"),
CAP_RESUME
CAP_RESUME,
]
def _cap_dict(s: str) -> Dict[str, str]:
d: Dict[str, str] = {}
for token in s.split(","):
@ -101,41 +99,44 @@ def _cap_dict(s: str) -> Dict[str, str]:
d[key] = value
return d
async def sts_transmute(params: ConnectionParams):
if not params.sts is None and params.tls is None:
now = time()
since = (now-params.sts.created)
now = time()
since = now - params.sts.created
if since <= params.sts.duration:
params.port = params.sts.port
params.tls = TLS_VERIFYCHAIN
params.tls = TLS_VERIFYCHAIN
async def resume_transmute(params: ConnectionParams):
if params.resume is not None:
params.host = params.resume.address
class HandshakeCancel(Exception):
pass
class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]):
await self._sts(tokens)
caps = list(self.server.desired_caps)+CAPS
caps = list(self.server.desired_caps) + CAPS
if (not self.server.params.sasl is None and
not CAP_SASL in caps):
if not self.server.params.sasl is None and not CAP_SASL in caps:
caps.append(CAP_SASL)
matched = (c.available(tokens) for c in caps)
matched = (c.available(tokens) for c in caps)
cap_names = [name for name in matched if not name is None]
if cap_names:
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
while cap_names:
line = await self.server.wait_for({
Response("CAP", [ANY, "ACK"]),
Response("CAP", [ANY, "NAK"])
})
line = await self.server.wait_for(
{Response("CAP", [ANY, "ACK"]), Response("CAP", [ANY, "NAK"])}
)
current_caps = line.params[2].split(" ")
for cap in current_caps:
@ -144,8 +145,7 @@ class CAPContext(ServerContext):
if CAP_RESUME.available(current_caps):
await self.resume_token()
if (self.server.cap_agreed(CAP_SASL) and
not self.server.params.sasl is None):
if self.server.cap_agreed(CAP_SASL) and not self.server.params.sasl is None:
await self.server.sasl_auth(self.server.params.sasl)
async def resume_token(self):
@ -160,10 +160,9 @@ class CAPContext(ServerContext):
if previous_policy is not None and not self.server.registered:
await self.server.send(build("RESUME", [previous_policy.token]))
line = await self.server.wait_for({
Response("RESUME", ["SUCCESS"]),
Response("FAIL", ["RESUME"])
})
line = await self.server.wait_for(
{Response("RESUME", ["SUCCESS"]), Response("FAIL", ["RESUME"])}
)
if line.command == "RESUME":
raise HandshakeCancel()
@ -179,11 +178,11 @@ class CAPContext(ServerContext):
cap_sts = CAP_STS.available(tokens)
if not cap_sts is None:
sts_dict = _cap_dict(tokens[cap_sts])
params = self.server.params
params = self.server.params
if not params.tls:
if "port" in sts_dict:
params.port = int(sts_dict["port"])
params.tls = TLS_VERIFYCHAIN
params.tls = TLS_VERIFYCHAIN
await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params)
@ -194,6 +193,6 @@ class CAPContext(ServerContext):
int(time()),
params.port,
int(sts_dict["duration"]),
"preload" in sts_dict)
"preload" in sts_dict,
)
await self.server.sts_policy(policy)

View File

@ -1,3 +1,2 @@
from .responses import *
from .params import *
from .params import *

View File

@ -1,16 +1,24 @@
from re import compile as re_compile
from typing import Optional, Pattern, Union
from irctokens import Hostmask
from ..interface import (IMatchResponseParam, IMatchResponseValueParam,
IMatchResponseHostmask, IServer)
from ..glob import Glob, compile as glob_compile
from re import compile as re_compile
from typing import Optional, Pattern, Union
from irctokens import Hostmask
from ..interface import (
IMatchResponseParam,
IMatchResponseValueParam,
IMatchResponseHostmask,
IServer,
)
from ..glob import Glob, compile as glob_compile
from .. import formatting
class Any(IMatchResponseParam):
def __repr__(self) -> str:
return "Any()"
def match(self, server: IServer, arg: str) -> bool:
return True
ANY = Any()
# NOT
@ -18,107 +26,142 @@ ANY = Any()
# REGEX
# LITERAL
class Literal(IMatchResponseValueParam):
def __init__(self, value: str):
self._value = value
def __repr__(self) -> str:
return f"{self._value!r}"
def value(self, server: IServer) -> str:
return self._value
def set_value(self, value: str):
self._value = value
def match(self, server: IServer, arg: str) -> bool:
return arg == self._value
TYPE_MAYBELIT = Union[str, IMatchResponseParam]
TYPE_MAYBELIT = Union[str, IMatchResponseParam]
TYPE_MAYBELIT_VALUE = Union[str, IMatchResponseValueParam]
def _assure_lit(value: TYPE_MAYBELIT_VALUE) -> IMatchResponseValueParam:
if isinstance(value, str):
return Literal(value)
else:
return value
class Not(IMatchResponseParam):
def __init__(self, param: IMatchResponseParam):
self._param = param
def __repr__(self) -> str:
return f"Not({self._param!r})"
def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(server, arg)
class ParamValuePassthrough(IMatchResponseValueParam):
_value: IMatchResponseValueParam
def value(self, server: IServer):
return self._value.value(server)
def set_value(self, value: str):
self._value.set_value(value)
class Folded(ParamValuePassthrough):
def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value)
self._folded = False
def __repr__(self) -> str:
return f"Folded({self._value!r})"
def match(self, server: IServer, arg: str) -> bool:
if not self._folded:
value = self.value(server)
value = self.value(server)
folded = server.casefold(value)
self.set_value(folded)
self._folded = True
return self._value.match(server, server.casefold(arg))
class Formatless(IMatchResponseParam):
def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value)
def __repr__(self) -> str:
brepr = super().__repr__()
return f"Formatless({brepr})"
def match(self, server: IServer, arg: str) -> bool:
strip = formatting.strip(arg)
return self._value.match(server, strip)
class Regex(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
self._pattern: Optional[Pattern] = None
def match(self, server: IServer, arg: str) -> bool:
if self._pattern is None:
self._pattern = re_compile(self._value)
return bool(self._pattern.search(arg))
class Self(IMatchResponseParam):
def __repr__(self) -> str:
return "Self()"
def match(self, server: IServer, arg: str) -> bool:
return server.casefold(arg) == server.nickname_lower
SELF = Self()
class MaskSelf(IMatchResponseHostmask):
def __repr__(self) -> str:
return "MaskSelf()"
def match(self, server: IServer, hostmask: Hostmask):
return server.casefold(hostmask.nickname) == server.nickname_lower
MASK_SELF = MaskSelf()
class Nick(IMatchResponseHostmask):
def __init__(self, nickname: str):
self._nickname = nickname
self._folded: Optional[str] = None
def __repr__(self) -> str:
return f"Nick({self._nickname!r})"
def match(self, server: IServer, hostmask: Hostmask):
if self._folded is None:
self._folded = server.casefold(self._nickname)
return self._folded == server.casefold(hostmask.nickname)
class Mask(IMatchResponseHostmask):
def __init__(self, mask: str):
self._mask = mask
self._compiled: Optional[Glob]
def __repr__(self) -> str:
return f"Mask({self._mask!r})"
def match(self, server: IServer, hostmask: Hostmask):
if self._compiled is None:
self._compiled = glob_compile(self._mask)

View File

@ -1,17 +1,25 @@
from typing import List, Optional, Sequence, Union
from irctokens import Line
from ..interface import (IServer, IMatchResponse, IMatchResponseParam,
IMatchResponseHostmask)
from .params import *
from typing import List, Optional, Sequence, Union
from irctokens import Line
from ..interface import (
IServer,
IMatchResponse,
IMatchResponseParam,
IMatchResponseHostmask,
)
from .params import *
TYPE_PARAM = Union[str, IMatchResponseParam]
class Responses(IMatchResponse):
def __init__(self,
commands: Sequence[str],
params: Sequence[TYPE_PARAM]=[],
source: Optional[IMatchResponseHostmask]=None):
def __init__(
self,
commands: Sequence[str],
params: Sequence[TYPE_PARAM] = [],
source: Optional[IMatchResponseHostmask] = None,
):
self._commands = commands
self._source = source
self._source = source
self._params: Sequence[IMatchResponseParam] = []
for param in params:
@ -25,36 +33,43 @@ class Responses(IMatchResponse):
def match(self, server: IServer, line: Line) -> bool:
for command in self._commands:
if (line.command == command and (
self._source is None or (
line.hostmask is not None and
self._source.match(server, line.hostmask)
))):
if line.command == command and (
self._source is None
or (
line.hostmask is not None
and self._source.match(server, line.hostmask)
)
):
for i, param in enumerate(self._params):
if (i >= len(line.params) or
not param.match(server, line.params[i])):
if i >= len(line.params) or not param.match(server, line.params[i]):
break
else:
return True
else:
return False
class Response(Responses):
def __init__(self,
command: str,
params: Sequence[TYPE_PARAM]=[],
source: Optional[IMatchResponseHostmask]=None):
def __init__(
self,
command: str,
params: Sequence[TYPE_PARAM] = [],
source: Optional[IMatchResponseHostmask] = None,
):
super().__init__([command], params, source=source)
def __repr__(self) -> str:
return f"Response({self._commands[0]}: {self._params!r})"
class ResponseOr(IMatchResponse):
def __init__(self, *responses: IMatchResponse):
self._responses = responses
def __repr__(self) -> str:
return f"ResponseOr({self._responses!r})"
def match(self, server: IServer, line: Line) -> bool:
for response in self._responses:
if response.match(server, line):

View File

@ -1,74 +1,80 @@
from re import compile as re_compile
from typing import List, Optional
from re import compile as re_compile
from typing import List, Optional
from dataclasses import dataclass, field
from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN
class SASLParams(object):
mechanism: str
@dataclass
class _SASLUserPass(SASLParams):
username: str
password: str
username: str
password: str
class SASLUserPass(_SASLUserPass):
mechanism = "USERPASS"
class SASLSCRAM(_SASLUserPass):
mechanism = "SCRAM"
class SASLExternal(SASLParams):
mechanism = "EXTERNAL"
@dataclass
class STSPolicy(object):
created: int
port: int
created: int
port: int
duration: int
preload: bool
preload: bool
@dataclass
class ResumePolicy(object):
address: str
token: str
token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = {
"+": TLS_VERIFYCHAIN,
"~": TLS_NOVERIFY
}
_TLS_TYPES = {"+": TLS_VERIFYCHAIN, "~": TLS_NOVERIFY}
@dataclass
class ConnectionParams(object):
nickname: str
host: str
port: int
tls: Optional[TLS] = TLS_VERIFYCHAIN
host: str
port: int
tls: Optional[TLS] = TLS_VERIFYCHAIN
username: Optional[str] = None
realname: Optional[str] = None
bindhost: Optional[str] = None
password: Optional[str] = None
sasl: Optional[SASLParams] = None
password: Optional[str] = None
sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None
sts: Optional[STSPolicy] = None
resume: Optional[ResumePolicy] = None
reconnect: int = 10 # seconds
reconnect: int = 10 # seconds
alt_nicknames: List[str] = field(default_factory=list)
autojoin: List[str] = field(default_factory=list)
autojoin: List[str] = field(default_factory=list)
@staticmethod
def from_hoststring(
nickname: str,
hoststring: str
) -> "ConnectionParams":
def from_hoststring(nickname: str, hoststring: str) -> "ConnectionParams":
ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:]
port_s = hoststring[ipv6host.end() + 1 :]
else:
host, _, port_s = hoststring.strip().partition(":")

View File

@ -1,52 +1,62 @@
from typing import List
from enum import Enum
from base64 import b64decode, b64encode
from typing import List
from enum import Enum
from base64 import b64decode, b64encode
from irctokens import build
from ircstates.numerics import *
from .matching import Responses, Response, ANY
from .contexts import ServerContext
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext, SCRAMAlgorithm
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext, SCRAMAlgorithm
SASL_SCRAM_MECHANISMS = [
"SCRAM-SHA-512",
"SCRAM-SHA-256",
"SCRAM-SHA-1",
]
SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"]
SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS + ["PLAIN"]
class SASLResult(Enum):
NONE = 0
NONE = 0
SUCCESS = 1
FAILURE = 2
ALREADY = 3
class SASLError(Exception):
pass
class SASLUnknownMechanismError(SASLError):
pass
AUTH_BYTE_MAX = 400
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
NUMERICS_FAIL = Response(ERR_SASLFAIL)
NUMERICS_INITIAL = Responses([
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
])
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
NUMERICS_FAIL = Response(ERR_SASLFAIL)
NUMERICS_INITIAL = Responses(
[ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED]
)
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii")
def _b64eb(s: bytes) -> str:
# encode-from-bytes
return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes:
# decode-to-bytes
return b64decode(s)
class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult:
if isinstance(params, SASLUserPass):
@ -57,15 +67,12 @@ class SASLContext(ServerContext):
return await self.external()
else:
raise SASLUnknownMechanismError(
"SASLParams given with unknown mechanism "
f"{params.mechanism!r}")
"SASLParams given with unknown mechanism " f"{params.mechanism!r}"
)
async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for({
AUTHENTICATE_ANY,
NUMERICS_INITIAL
})
line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL})
if line.command == "907":
# we've done SASL already. cleanly abort
@ -73,8 +80,8 @@ class SASLContext(ServerContext):
elif line.command == "908":
available = line.params[1].split(",")
raise SASLUnknownMechanismError(
"Server does not support SASL EXTERNAL "
f"(it supports {available}")
"Server does not support SASL EXTERNAL " f"(it supports {available}"
)
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.server.send(build("AUTHENTICATE", ["+"]))
@ -89,11 +96,12 @@ class SASLContext(ServerContext):
async def scram(self, username: str, password: str) -> SASLResult:
return await self.userpass(username, password, SASL_SCRAM_MECHANISMS)
async def userpass(self,
username: str,
password: str,
mechanisms: List[str]=SASL_USERPASS_MECHANISMS
) -> SASLResult:
async def userpass(
self,
username: str,
password: str,
mechanisms: List[str] = SASL_USERPASS_MECHANISMS,
) -> SASLResult:
def _common(server_mechs) -> List[str]:
mechs: List[str] = []
for our_mech in mechanisms:
@ -106,23 +114,21 @@ class SASLContext(ServerContext):
raise SASLUnknownMechanismError(
"No matching SASL mechanims. "
f"(we want: {mechanisms} "
f"server has: {server_mechs})")
f"server has: {server_mechs})"
)
if self.server.available_caps["sasl"]:
# CAP v3.2 tells us what mechs it supports
available = self.server.available_caps["sasl"].split(",")
match = _common(available)
match = _common(available)
else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported
match = mechanisms
match = mechanisms
while match:
await self.server.send(build("AUTHENTICATE", [match[0]]))
line = await self.server.wait_for({
AUTHENTICATE_ANY,
NUMERICS_INITIAL
})
line = await self.server.wait_for({AUTHENTICATE_ANY, NUMERICS_INITIAL})
if line.command == "907":
# we've done SASL already. cleanly abort
@ -130,7 +136,7 @@ class SASLContext(ServerContext):
elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",")
match = _common(available)
match = _common(available)
await self.server.wait_for(NUMERICS_FAIL)
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text = ""
@ -138,8 +144,7 @@ class SASLContext(ServerContext):
if match[0] == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
elif match[0].startswith("SCRAM-SHA-"):
auth_text = await self._scram(
match[0], username, password)
auth_text = await self._scram(match[0], username, password)
if not auth_text == "+":
auth_text = _b64e(auth_text)
@ -148,7 +153,7 @@ class SASLContext(ServerContext):
await self._send_auth_text(auth_text)
line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903":
if line.command == "903":
return SASLResult.SUCCESS
elif line.command == "904":
match.pop(0)
@ -157,11 +162,8 @@ class SASLContext(ServerContext):
return SASLResult.FAILURE
async def _scram(self, algo_str: str,
username: str,
password: str) -> str:
algo_str_prep = algo_str.replace("SCRAM-", "", 1
).replace("-", "").upper()
async def _scram(self, algo_str: str, username: str, password: str) -> str:
algo_str_prep = algo_str.replace("SCRAM-", "", 1).replace("-", "").upper()
try:
algo = SCRAMAlgorithm(algo_str_prep)
except ValueError:
@ -179,15 +181,15 @@ class SASLContext(ServerContext):
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_final = _b64db(line.params[0])
verified = scram.server_final(server_final)
#TODO PANIC if verified is false!
verified = scram.server_final(server_final)
# TODO PANIC if verified is false!
return "+"
else:
return ""
async def _send_auth_text(self, text: str):
n = AUTH_BYTE_MAX
chunks = [text[i:i+n] for i in range(0, len(text), n)]
chunks = [text[i : i + n] for i in range(0, len(text), n)]
if len(chunks[-1]) == 400:
chunks.append("+")

View File

@ -7,51 +7,60 @@ from typing import Dict
# https://www.iana.org/assignments/hash-function-text-names/
# MD2 has been removed as it's unacceptably weak
class SCRAMAlgorithm(Enum):
MD5 = "MD5"
SHA_1 = "SHA1"
MD5 = "MD5"
SHA_1 = "SHA1"
SHA_224 = "SHA224"
SHA_256 = "SHA256"
SHA_384 = "SHA384"
SHA_512 = "SHA512"
SCRAM_ERRORS = [
"invalid-encoding",
"extensions-not-supported", # unrecognized 'm' value
"extensions-not-supported", # unrecognized 'm' value
"invalid-proof",
"channel-bindings-dont-match",
"server-does-support-channel-binding",
"channel-binding-not-supported",
"unsupported-channel-binding-type",
"unknown-user",
"invalid-username-encoding", # invalid utf8 or bad SASLprep
"no-resources"
"invalid-username-encoding", # invalid utf8 or bad SASLprep
"no-resources",
]
def _scram_nonce() -> bytes:
return base64.b64encode(os.urandom(32))
def _scram_escape(s: bytes) -> bytes:
return s.replace(b"=", b"=3D").replace(b",", b"=2C")
def _scram_unescape(s: bytes) -> bytes:
return s.replace(b"=3D", b"=").replace(b"=2C", b",")
def _scram_xor(s1: bytes, s2: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(s1, s2))
class SCRAMState(Enum):
NONE = 0
CLIENT_FIRST = 1
CLIENT_FINAL = 2
SUCCESS = 3
FAILURE = 4
NONE = 0
CLIENT_FIRST = 1
CLIENT_FINAL = 2
SUCCESS = 3
FAILURE = 4
VERIFY_FAILURE = 5
class SCRAMError(Exception):
pass
class SCRAMContext(object):
def __init__(self, algo: SCRAMAlgorithm,
username: str,
password: str):
self._algo = algo
def __init__(self, algo: SCRAMAlgorithm, username: str, password: str):
self._algo = algo
self._username = username.encode("utf8")
self._password = password.encode("utf8")
@ -59,11 +68,11 @@ class SCRAMContext(object):
self.error = ""
self.raw_error = ""
self._client_first = b""
self._client_nonce = b""
self._client_first = b""
self._client_nonce = b""
self._salted_password = b""
self._auth_message = b""
self._auth_message = b""
def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]:
pieces = (piece.split(b"=", 1) for piece in data.split(b","))
@ -71,6 +80,7 @@ class SCRAMContext(object):
def _hmac(self, key: bytes, msg: bytes) -> bytes:
return hmac.new(key, msg, self._algo.value).digest()
def _hash(self, msg: bytes) -> bytes:
return hashlib.new(self._algo.value, msg).digest()
@ -89,7 +99,9 @@ class SCRAMContext(object):
self.state = SCRAMState.CLIENT_FIRST
self._client_nonce = _scram_nonce()
self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), self._client_nonce)
_scram_escape(self._username),
self._client_nonce,
)
# n,,n=<username>,r=<nonce>
return b"n,,%s" % self._client_first
@ -109,17 +121,17 @@ class SCRAMContext(object):
if self._assert_error(pieces):
return b""
nonce = pieces[b"r"] # server combines your nonce with it's own
if (not nonce.startswith(self._client_nonce) or
nonce == self._client_nonce):
nonce = pieces[b"r"] # server combines your nonce with it's own
if not nonce.startswith(self._client_nonce) or nonce == self._client_nonce:
self._fail("nonce-unacceptable")
return b""
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
iterations = int(pieces[b"i"])
salted_password = hashlib.pbkdf2_hmac(self._algo.value,
self._password, salt, iterations, dklen=None)
salted_password = hashlib.pbkdf2_hmac(
self._algo.value, self._password, salt, iterations, dklen=None
)
self._salted_password = salted_password
client_key = self._hmac(salted_password, b"Client Key")

View File

@ -1,26 +1,35 @@
import ssl
class TLS:
pass
# tls without verification
class TLSNoVerify(TLS):
pass
TLS_NOVERIFY = TLSNoVerify()
# verify via CAs
class TLSVerifyChain(TLS):
pass
TLS_VERIFYCHAIN = TLSVerifyChain()
# verify by a pinned hash
class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str):
self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash):
pass
def tls_context(verify: bool=True) -> ssl.SSLContext:
def tls_context(verify: bool = True) -> ssl.SSLContext:
ctx = ssl.create_default_context()
if not verify:
ctx.check_hostname = False

View File

@ -1,36 +1,58 @@
import asyncio
from asyncio import Future, PriorityQueue
from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List,
Optional, Set, Tuple, Union)
from asyncio import Future, PriorityQueue
from typing import (
AsyncIterable,
Awaitable,
Deque,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from collections import deque
from time import monotonic
from time import monotonic
import anyio
from asyncio_rlock import RLock
from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser
from asyncio_rlock import RLock
from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser
from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException
from ircstates.names import Name
from irctokens import build, Line, tokenise
from ircstates.server import ServerDisconnectedException
from ircstates.names import Name
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
from .sasl import SASLContext, SASLResult
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
Folded)
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse)
from .ircv3 import (
CAPContext,
sts_transmute,
CAP_ECHO,
CAP_SASL,
CAP_LABEL,
LABEL_TAG_MAP,
resume_transmute,
)
from .sasl import SASLContext, SASLResult
from .matching import ResponseOr, Responses, Response, ANY, SELF, MASK_SELF, Folded
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (
IBot,
ICapability,
IServer,
SentLine,
SendPriority,
IMatchResponse,
)
from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds
PING_TIMEOUT = 60 # seconds
WAIT_TIMEOUT = 20 # seconds
PING_TIMEOUT = 60 # seconds
WAIT_TIMEOUT = 20 # seconds
JOIN_ERR_FIRST = [
ERR_NOSUCHCHANNEL,
@ -41,13 +63,14 @@ JOIN_ERR_FIRST = [
ERR_INVITEONLYCHAN,
ERR_BADCHANNELKEY,
ERR_NEEDREGGEDNICK,
ERR_THROTTLE
ERR_THROTTLE,
]
class Server(IServer):
_reader: ITCPReader
_writer: ITCPWriter
params: ConnectionParams
params: ConnectionParams
def __init__(self, bot: IBot, name: str):
super().__init__(name)
@ -58,23 +81,23 @@ class Server(IServer):
self.throttle = Throttler(rate_limit=100, period=1)
self.sasl_state = SASLResult.NONE
self.last_read = monotonic()
self.last_read = monotonic()
self._sent_count: int = 0
self._sent_count: int = 0
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Line] = deque()
self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._ping_sent = False
self._ping_sent = False
self._read_lguard = RLock()
self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event()
self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event()
self._pending_who: Deque[str] = deque()
self._alt_nicks: List[str] = []
self._alt_nicks: List[str] = []
def hostmask(self) -> str:
hostmask = self.nickname
@ -84,13 +107,10 @@ class Server(IServer):
hostmask += f"@{self.hostname}"
return hostmask
def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
def send_raw(self, line: str, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
return self.send(tokenise(line), priority)
def send(self,
line: Line,
priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
def send(self, line: Line, priority=SendPriority.DEFAULT) -> Awaitable[SentLine]:
self.line_presend(line)
sent_line = SentLine(self._sent_count, priority, line)
@ -110,28 +130,25 @@ class Server(IServer):
def set_throttle(self, rate: int, time: float):
self.throttle.rate_limit = rate
self.throttle.period = time
self.throttle.period = time
def server_address(self) -> Tuple[str, int]:
return self._writer.get_peer()
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
async def connect(self, transport: ITCPTransport, params: ConnectionParams):
await sts_transmute(params)
await resume_transmute(params)
reader, writer = await transport.connect(
params.host,
params.port,
tls =params.tls,
bindhost =params.bindhost)
params.host, params.port, tls=params.tls, bindhost=params.bindhost
)
self._reader = reader
self._writer = writer
self.params = params
await self.handshake()
async def disconnect(self):
if not self._writer is None:
await self._writer.close()
@ -145,29 +162,35 @@ class Server(IServer):
alt_nicks = self.params.alt_nicknames
if not alt_nicks:
alt_nicks = [nickname+"_"*i for i in range(1, 4)]
self._alt_nicks = alt_nicks
alt_nicks = [nickname + "_" * i for i in range(1, 4)]
self._alt_nicks = alt_nicks
# these must remain non-awaited; reading hasn't started yet
if not self.params.password is None:
self.send(build("PASS", [self.params.password]))
self.send(build("CAP", ["LS", "302"]))
self.send(build("CAP", ["LS", "302"]))
self.send(build("NICK", [nickname]))
self.send(build("USER", [username, "0", "*", realname]))
# to be overridden
def line_preread(self, line: Line):
pass
def line_presend(self, line: Line):
pass
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
pass
async def sts_policy(self, sts: STSPolicy):
pass
async def resume_policy(self, resume: ResumePolicy):
pass
# /to be overriden
async def _on_read(self, line: Line, emit: Optional[Emit]):
@ -176,13 +199,14 @@ class Server(IServer):
elif line.command == RPL_ENDOFWHO:
chan = self.casefold(line.params[1])
if (self._pending_who and
self._pending_who[0] == chan):
if self._pending_who and self._pending_who[0] == chan:
self._pending_who.popleft()
await self._next_who()
elif (line.command in {
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
} and not self.registered):
elif (
line.command
in {ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE}
and not self.registered
):
if self._alt_nicks:
nick = self._alt_nicks.pop(0)
await self.send(build("NICK", [nick]))
@ -203,8 +227,7 @@ class Server(IServer):
await self._check_regain([line.params[1]])
elif line.command == RPL_MONOFFLINE:
await self._check_regain(line.params[1].split(","))
elif (line.command in ["NICK", "QUIT"] and
line.source is not None):
elif line.command in ["NICK", "QUIT"] and line.source is not None:
await self._check_regain([line.hostmask.nickname])
elif emit is not None:
@ -216,10 +239,9 @@ class Server(IServer):
await self._batch_joins(self.params.autojoin)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
elif emit.subcommand == "LS" and emit.finished:
if not self.registered:
await CAPContext(self).handshake()
else:
@ -227,7 +249,7 @@ class Server(IServer):
elif emit.command == "JOIN":
if emit.self and not emit.channel is None:
chan = emit.channel.name_lower
chan = emit.channel.name_lower
await self.send(build("MODE", [chan]))
modes = "".join(self.isupport.chanmodes.a_modes)
@ -241,18 +263,18 @@ class Server(IServer):
async def _check_regain(self, nicks: List[str]):
for nick in nicks:
if (self.casefold_equals(nick, self.params.nickname) and
not self.nickname == self.params.nickname):
if (
self.casefold_equals(nick, self.params.nickname)
and not self.nickname == self.params.nickname
):
await self.send(build("NICK", [self.params.nickname]))
async def _batch_joins(self,
channels: List[str],
batch_n: int=10):
#TODO: do as many JOINs in one line as we can fit
#TODO: channel keys
async def _batch_joins(self, channels: List[str], batch_n: int = 10):
# TODO: do as many JOINs in one line as we can fit
# TODO: channel keys
for i in range(0, len(channels), batch_n):
batch = channels[i:i+batch_n]
batch = channels[i : i + batch_n]
await self.send(build("JOIN", [",".join(batch)]))
async def _next_who(self):
@ -275,7 +297,7 @@ class Server(IServer):
return None
self.last_read = monotonic()
lines = self.recv(data)
lines = self.recv(data)
for line in lines:
self.line_preread(line)
self._read_queue.append(line)
@ -287,10 +309,10 @@ class Server(IServer):
if not self._process_queue:
async with self._read_lwork:
read_aw = self._read_line(PING_TIMEOUT)
read_aw = self._read_line(PING_TIMEOUT)
dones, notdones = await asyncio.wait(
[read_aw, self._wait_for.wait()],
return_when=asyncio.FIRST_COMPLETED
return_when=asyncio.FIRST_COMPLETED,
)
self._wait_for.clear()
@ -314,11 +336,12 @@ class Server(IServer):
line, emit = self._process_queue.popleft()
await self._on_read(line, emit)
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None,
timeout: float=WAIT_TIMEOUT
) -> Line:
async def wait_for(
self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]] = None,
timeout: float = WAIT_TIMEOUT,
) -> Line:
response_obj: IMatchResponse
if isinstance(response, set):
@ -340,8 +363,9 @@ class Server(IServer):
return line
async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
not self.cap_agreed(CAP_ECHO)):
if line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and not self.cap_agreed(
CAP_ECHO
):
new_line = line.with_source(self.hostmask())
self._read_queue.append(new_line)
@ -349,15 +373,13 @@ class Server(IServer):
while True:
lines: List[SentLine] = []
while (not lines or
(len(lines) < 5 and self._send_queue.qsize() > 0)):
while not lines or (len(lines) < 5 and self._send_queue.qsize() > 0):
prio_line = await self._send_queue.get()
lines.append(prio_line)
for line in lines:
async with self.throttle:
self._writer.write(
f"{line.line.format()}\r\n".encode("utf8"))
self._writer.write(f"{line.line.format()}\r\n".encode("utf8"))
await self._writer.drain()
@ -369,6 +391,7 @@ class Server(IServer):
# CAP-related
def cap_agreed(self, capability: ICapability) -> bool:
return bool(self.cap_available(capability))
def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps)
@ -381,78 +404,81 @@ class Server(IServer):
await CAPContext(self).on_ls(tokens)
async def sasl_auth(self, params: SASLParams) -> bool:
if (self.sasl_state == SASLResult.NONE and
self.cap_agreed(CAP_SASL)):
if self.sasl_state == SASLResult.NONE and self.cap_agreed(CAP_SASL):
res = await SASLContext(self).from_params(params)
self.sasl_state = res
return True
else:
return False
# /CAP-related
def send_nick(self, new_nick: str) -> Awaitable[bool]:
fut = self.send(build("NICK", [new_nick]))
async def _assure() -> bool:
line = await self.wait_for({
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
Responses([
ERR_BANNICKCHANGE,
ERR_NICKTOOFAST,
ERR_CANTCHANGENICK
], [ANY]),
Responses([
ERR_NICKNAMEINUSE,
ERR_ERRONEUSNICKNAME,
ERR_UNAVAILRESOURCE
], [ANY, Folded(new_nick)])
}, fut)
line = await self.wait_for(
{
Response("NICK", [Folded(new_nick)], source=MASK_SELF),
Responses(
[ERR_BANNICKCHANGE, ERR_NICKTOOFAST, ERR_CANTCHANGENICK], [ANY]
),
Responses(
[ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE],
[ANY, Folded(new_nick)],
),
},
fut,
)
return line.command == "NICK"
return MaybeAwait(_assure)
def send_join(self,
name: str,
key: Optional[str]=None
) -> Awaitable[Channel]:
def send_join(self, name: str, key: Optional[str] = None) -> Awaitable[Channel]:
fut = self.send_joins([name], [] if key is None else [key])
async def _assure():
channels = await fut
return channels[0]
return MaybeAwait(_assure)
def send_part(self, name: str):
fut = self.send(build("PART", [name]))
async def _assure():
line = await self.wait_for(
Response("PART", [Folded(name)], source=MASK_SELF),
fut
Response("PART", [Folded(name)], source=MASK_SELF), fut
)
return
return MaybeAwait(_assure)
def send_joins(self,
names: List[str],
keys: List[str]=[]
) -> Awaitable[List[Channel]]:
def send_joins(
self, names: List[str], keys: List[str] = []
) -> Awaitable[List[Channel]]:
folded_names = [self.casefold(name) for name in names]
if not keys:
fut = self.send(build("JOIN", [",".join(names)]))
else:
fut = self.send(build("JOIN", [",".join(names)]+keys))
fut = self.send(build("JOIN", [",".join(names)] + keys))
async def _assure():
channels: List[Channel] = []
while folded_names:
line = await self.wait_for({
Response(RPL_CHANNELMODEIS, [ANY, ANY]),
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
}, fut)
line = await self.wait_for(
{
Response(RPL_CHANNELMODEIS, [ANY, ANY]),
Responses(JOIN_ERR_FIRST, [ANY, ANY]),
Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]),
},
fut,
)
chan: Optional[str] = None
if line.command == RPL_CHANNELMODEIS:
@ -462,7 +488,7 @@ class Server(IServer):
elif line.command == ERR_USERONCHANNEL:
chan = line.params[2]
elif line.command == ERR_LINKCHANNEL:
#XXX i dont like this
# XXX i dont like this
chan = line.params[2]
await self.wait_for(
Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)])
@ -477,51 +503,58 @@ class Server(IServer):
channels.append(self.channels[folded])
return channels
return MaybeAwait(_assure)
def send_message(self, target: str, message: str
) -> Awaitable[Optional[str]]:
def send_message(self, target: str, message: str) -> Awaitable[Optional[str]]:
fut = self.send(build("PRIVMSG", [target, message]))
async def _assure():
line = await self.wait_for(
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF),
fut
Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), fut
)
if line.command == "PRIVMSG":
return line.params[1]
else:
return None
return MaybeAwait(_assure)
def send_whois(self,
target: str,
remote: bool=False
) -> Awaitable[Optional[Whois]]:
def send_whois(
self, target: str, remote: bool = False
) -> Awaitable[Optional[Whois]]:
args = [target]
if remote:
args.append(target)
fut = self.send(build("WHOIS", args))
async def _assure() -> Optional[Whois]:
folded = self.casefold(target)
params = [ANY, Folded(folded)]
obj = Whois()
while True:
line = await self.wait_for(Responses([
ERR_NOSUCHNICK,
ERR_NOSUCHSERVER,
RPL_WHOISUSER,
RPL_WHOISSERVER,
RPL_WHOISOPERATOR,
RPL_WHOISIDLE,
RPL_WHOISCHANNELS,
RPL_WHOISHOST,
RPL_WHOISACCOUNT,
RPL_WHOISSECURE,
RPL_ENDOFWHOIS
], params), fut)
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
line = await self.wait_for(
Responses(
[
ERR_NOSUCHNICK,
ERR_NOSUCHSERVER,
RPL_WHOISUSER,
RPL_WHOISSERVER,
RPL_WHOISOPERATOR,
RPL_WHOISIDLE,
RPL_WHOISCHANNELS,
RPL_WHOISHOST,
RPL_WHOISACCOUNT,
RPL_WHOISSECURE,
RPL_ENDOFWHOIS,
],
params,
),
fut,
)
if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
return None
elif line.command == RPL_WHOISUSER:
nick, user, host, _, real = line.params[1:]
@ -531,7 +564,7 @@ class Server(IServer):
obj.realname = real
elif line.command == RPL_WHOISIDLE:
idle, signon, _ = line.params[2:]
obj.idle = int(idle)
obj.idle = int(idle)
obj.signon = int(signon)
elif line.command == RPL_WHOISACCOUNT:
obj.account = line.params[2]
@ -544,11 +577,11 @@ class Server(IServer):
symbols = ""
while channel[0] in self.isupport.prefix.prefixes:
symbols += channel[0]
channel = channel[1:]
channel = channel[1:]
channel_user = ChannelUser(
Name(obj.nickname, folded),
Name(channel, self.casefold(channel))
Name(channel, self.casefold(channel)),
)
for symbol in symbols:
mode = self.isupport.prefix.from_prefix(symbol)
@ -558,4 +591,5 @@ class Server(IServer):
obj.channels.append(channel_user)
elif line.command == RPL_ENDOFWHOIS:
return obj
return MaybeAwait(_assure)

View File

@ -3,21 +3,21 @@ from dataclasses import dataclass
from ircstates import ChannelUser
class Whois(object):
server: Optional[str] = None
server_info: Optional[str] = None
operator: bool = False
server: Optional[str] = None
server_info: Optional[str] = None
operator: bool = False
secure: bool = False
secure: bool = False
signon: Optional[int] = None
idle: Optional[int] = None
signon: Optional[int] = None
idle: Optional[int] = None
channels: Optional[List[ChannelUser]] = None
channels: Optional[List[ChannelUser]] = None
nickname: str = ""
username: str = ""
hostname: str = ""
realname: str = ""
account: Optional[str] = None
account: Optional[str] = None

View File

@ -1,12 +1,12 @@
from hashlib import sha512
from ssl import SSLContext
from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter
from hashlib import sha512
from ssl import SSLContext
from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
TLSVerifySHA512)
from .security import tls_context, TLS, TLSNoVerify, TLSVerifyHash, TLSVerifySHA512
class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader):
@ -14,6 +14,8 @@ class TCPReader(ITCPReader):
async def read(self, byte_count: int) -> bytes:
return await self._reader.read(byte_count)
class TCPWriter(ITCPWriter):
def __init__(self, writer: StreamWriter):
self._writer = writer
@ -32,13 +34,15 @@ class TCPWriter(ITCPWriter):
self._writer.close()
await self._writer.wait_closed()
class TCPTransport(ITCPTransport):
async def connect(self,
hostname: str,
port: int,
tls: Optional[TLS],
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]:
async def connect(
self,
hostname: str,
port: int,
tls: Optional[TLS],
bindhost: Optional[str] = None,
) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None
if tls is not None:
@ -54,22 +58,20 @@ class TCPTransport(ITCPTransport):
hostname,
port,
server_hostname=server_hostname,
ssl =cur_ssl,
local_addr =local_addr)
ssl=cur_ssl,
local_addr=local_addr,
)
if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info(
"ssl_object"
).getpeercert(True)
cert: bytes = writer.transport.get_extra_info("ssl_object").getpeercert(
True
)
if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest()
else:
raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum:
raise ValueError(
f"pinned hash for {hostname} does not match ({sum})"
)
raise ValueError(f"pinned hash for {hostname} does not match ({sum})")
return (TCPReader(reader), TCPWriter(writer))

View File

@ -24,8 +24,8 @@ setup(
"Operating System :: OS Independent",
"Operating System :: POSIX",
"Operating System :: Microsoft :: Windows",
"Topic :: Communications :: Chat :: Internet Relay Chat"
"Topic :: Communications :: Chat :: Internet Relay Chat",
],
python_requires='>=3.7',
install_requires=install_requires
python_requires=">=3.7",
install_requires=install_requires,
)

View File

@ -1,6 +1,7 @@
import unittest
from ircrobots import glob
class GlobTestCollapse(unittest.TestCase):
def test(self):
c1 = glob.collapse("**?*")