Compare commits

...

58 Commits

Author SHA1 Message Date
jesopo 7c9a144124 v0.6.6 release 2023-08-17 22:48:21 +00:00
jesopo e3c91a50e1 update ircstates 2023-08-17 22:47:57 +00:00
jesopo f2ba48a582 v0.6.5 release 2023-07-06 00:57:08 +00:00
jesopo cf2e69a9e2 asyncio.wait(..) now requires Tasks 2023-07-06 00:56:45 +00:00
jesopo a1a459c13e v0.6.4 release 2023-07-06 00:44:25 +00:00
jesopo 81fa77cf29 missed some TLS_ uses 2023-07-06 00:44:13 +00:00
jesopo 422a9a93c1 v0.6.3 release 2023-07-06 00:35:47 +00:00
jesopo b04a0e0136 python no longer likes having mutables in non-default_factory 2023-07-06 00:35:13 +00:00
jesopo 7bb4c3d069 v0.6.2 release 2023-02-06 19:43:14 +00:00
jesopo 9a2f2156fe support specifying tls client keypair 2023-02-06 19:42:27 +00:00
alicetries 0435404ec3 Small tweak to how repr() of Formatless() displays 2022-03-28 23:43:27 +01:00
jesopo 63025af311 v0.6.1 release 2022-02-19 13:51:05 +00:00
jesopo 20c4f8f98c upgrade async-timeout to v4.0.2 2022-02-19 13:49:21 +00:00
jesopo 0ce3b9b0b0 v0.6.0 release 2022-01-24 10:01:51 +00:00
jesopo 5b347f95c9 combine params.tls and .tls_verify, support pinned certs 2022-01-24 09:53:32 +00:00
jesopo 0a5c774965 v0.5.0 release 2022-01-20 21:28:23 +00:00
jesopo 8245a411c0 hmm no this isnt how you ask for cert validation apparently 2022-01-20 21:24:54 +00:00
jesopo 80b941fa53 handle ERR_UNAVAILRESOURCE for prereg NICK failure too 2022-01-16 15:24:22 +00:00
jesopo b7019d35c1 upgrade ircstates 2022-01-07 19:04:48 +00:00
jesopo 66358f77e3 v0.4.7 release 2022-01-07 11:45:32 +00:00
jesopo fcd2f5b1b2 upgrade ircstates 2022-01-07 11:43:30 +00:00
jesopo 3e18deef86 we don't support py3.6; support py3.9 2022-01-07 11:41:35 +00:00
jesopo 9ba5b2b90f add `transport` (ITCPTransport) param to bot.add_server 2021-12-18 16:48:01 +00:00
jesopo 025fde97ee v0.4.6 release 2021-12-09 23:53:14 +00:00
jesopo 05750f00d9 make sure 'tls' is defined 2021-12-09 23:49:40 +00:00
jesopo ac4c144d58 v0.4.5 release 2021-11-29 16:11:54 +00:00
jesopo 6c91ebc7ec add ConnectionParams.from_hoststring("nick", "host:+port") 2021-11-29 16:09:26 +00:00
jesopo 0edcbfa234 v0.4.4 release 2021-09-19 21:36:20 +00:00
jesopo 7b6a845927 don't infinitely loop SASLUserPass attempts on FAIL or ABORT 2021-09-19 21:34:57 +00:00
jesopo dfd78b3d3e v0.4.3 release 2021-09-19 21:32:02 +00:00
jesopo ab65e39ab9 handle ERR_SASLABORTED 2021-09-18 17:34:52 +00:00
jesopo 9ca1ec21c9 v0.4.2 release 2021-09-18 17:15:53 +00:00
jesopo a03f11449c upgrade ircstates to v0.11.10 2021-09-18 17:11:40 +00:00
jesopo bb87c86b37 v0.4.1 release 2021-09-11 15:44:15 +00:00
jesopo 8ee692f1be upgrade ircstates to 0.11.9 2021-09-11 15:43:40 +00:00
jesopo c7604686a2 channel_user.modes is now a set 2021-09-11 15:42:47 +00:00
jesopo 64935c7a8d react to pre-reg ERR_ERRONEUSNICKNAME the same as ERR_NICKNAMEINUSE 2021-09-11 15:40:18 +00:00
jesopo fb93d59c43 v0.4.0 release 2021-06-26 15:11:32 +00:00
jesopo ab17645d83 catch reconnection failures, do exponential backoff 2021-06-26 15:08:48 +00:00
jesopo 8d3681eba1 freenode is dead long live libera.chat 2021-05-24 18:08:26 +00:00
jesopo 930342d74f v0.3.14 release 2021-05-22 08:43:50 +00:00
jesopo dd41b0dbde parse tokens in wait_for - waity things expect state change 2021-05-22 08:43:11 +00:00
jesopo f22471993a v0.3.13 release 2021-05-12 12:35:39 +00:00
jesopo 6fddfb7fe9 reset ping_sent in wait_for too 2021-05-12 12:34:06 +00:00
jesopo b4eaf6c24c v0.3.12 release 2021-05-12 11:56:24 +00:00
jesopo bdfb91b51d invert ping check 2021-05-12 11:52:33 +00:00
jesopo a14c7c34a2 v0.3.11 release 2021-05-12 11:28:27 +00:00
jesopo 3574868458 reset ping timer when we read a line 2021-05-12 11:24:54 +00:00
jesopo 0253aba99e v0.3.10 release 2021-05-12 10:58:51 +00:00
jesopo bfb5b4ec61 v0.3.9 release 2021-05-12 10:54:47 +00:00
jesopo 6a05370a12 simplify wait_for 2021-05-12 10:52:39 +00:00
jesopo 90fb4b7bba v0.3.8 release 2021-04-10 13:55:04 +00:00
jesopo d0c6b4a43d update ircstates to 0.11.8 2021-04-10 13:54:00 +00:00
jesopo fc0e8470cc change pre-001 throttle to 100 lines in 1 second 2021-03-26 12:35:02 +00:00
jesopo d0e0314169 _check_regain wants a string list, not a string 2020-12-20 00:42:18 +00:00
jesopo a15e2bd1fb "001" literal -> RPL_WELCOME 2020-12-20 00:40:31 +00:00
jesopo 7a59ece687 try to regain nick on servers that have WATCH or MONITOR 2020-12-20 00:40:11 +00:00
jesopo e7779bcf17 update ircstates to v0.11.7 2020-12-20 00:39:26 +00:00
18 changed files with 215 additions and 142 deletions

View File

@ -3,7 +3,7 @@ cache: pip
python: python:
- "3.7" - "3.7"
- "3.8" - "3.8"
- "3.8-dev" - "3.9"
install: install:
- pip3 install mypy -r requirements.txt - pip3 install mypy -r requirements.txt
script: script:

View File

@ -11,4 +11,4 @@ see [examples/](examples/) for some usage demonstration.
## contact ## contact
Come say hi at [##irctokens on freenode](https://webchat.freenode.net/?channels=%23%23irctokens) Come say hi at `#irctokens` on irc.libera.chat

View File

@ -1 +1 @@
0.3.7 0.6.6

View File

@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
params = ConnectionParams( params = ConnectionParams(
nickname, nickname,
hostname, hostname,
6697, 6697
tls=True) )
await bot.add_server("freenode", params) await bot.add_server("freenode", params)
await bot.run() await bot.run()

View File

@ -23,7 +23,6 @@ async def main():
"MyNickname", "MyNickname",
host = "chat.freenode.invalid", host = "chat.freenode.invalid",
port = 6697, port = 6697,
tls = True,
sasl = sasl_params) sasl = sasl_params)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)

View File

@ -25,7 +25,7 @@ class Bot(BaseBot):
async def main(): async def main():
bot = Bot() bot = Bot()
for name, host in SERVERS: for name, host in SERVERS:
params = ConnectionParams("BitBotNewTest", host, 6697, True) params = ConnectionParams("BitBotNewTest", host, 6697)
await bot.add_server(name, params) await bot.add_server(name, params)
await bot.run() await bot.run()

View File

@ -3,3 +3,4 @@ from .server import Server
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy, ResumePolicy) STSPolicy, ResumePolicy)
from .ircv3 import Capability from .ircv3 import Capability
from .security import TLS

View File

@ -1,4 +1,4 @@
import asyncio import asyncio, traceback
import anyio import anyio
from typing import Dict from typing import Dict
@ -6,32 +6,45 @@ from ircstates.server import ServerDisconnectedException
from .server import ConnectionParams, Server from .server import ConnectionParams, Server
from .transport import TCPTransport from .transport import TCPTransport
from .interface import IBot, IServer from .interface import IBot, IServer, ITCPTransport
class Bot(IBot): class Bot(IBot):
def __init__(self): def __init__(self):
self.servers: Dict[str, Server] = {} self.servers: Dict[str, Server] = {}
self._server_queue: asyncio.Queue[Server] = asyncio.Queue() self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
# methods designed to be overridden
def create_server(self, name: str): def create_server(self, name: str):
return Server(self, name) return Server(self, name)
async def disconnected(self, server: IServer): async def disconnected(self, server: IServer):
if (server.name in self.servers and if (server.name in self.servers and
server.params is not None and server.params is not None and
server.disconnected): server.disconnected):
await asyncio.sleep(server.params.reconnect)
await self.add_server(server.name, server.params) reconnect = server.params.reconnect
# /methods designed to be overridden
while True:
await asyncio.sleep(reconnect)
try:
await self.add_server(server.name, server.params)
except Exception as e:
traceback.print_exc()
# let's try again, exponential backoff up to 5 mins
reconnect = min(reconnect*2, 300)
else:
break
async def disconnect(self, server: IServer): async def disconnect(self, server: IServer):
await server.disconnect()
del self.servers[server.name] del self.servers[server.name]
await server.disconnect()
async def add_server(self, name: str, params: ConnectionParams) -> Server: async def add_server(self,
name: str,
params: ConnectionParams,
transport: ITCPTransport = TCPTransport()) -> Server:
server = self.create_server(name) server = self.create_server(name)
self.servers[name] = server self.servers[name] = server
await server.connect(TCPTransport(), params) await server.connect(transport, params)
await self._server_queue.put(server) await self._server_queue.put(server)
return server return server

View File

@ -6,6 +6,7 @@ from ircstates import Server, Emit
from irctokens import Line, Hostmask from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
@ -24,11 +25,10 @@ class ITCPWriter(object):
class ITCPTransport(object): class ITCPTransport(object):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: bool, tls: Optional[TLS],
tls_verify: bool=True, bindhost: Optional[str]=None
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
pass pass

View File

@ -8,6 +8,7 @@ from .contexts import ServerContext
from .matching import Response, ANY from .matching import Response, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLSVerifyChain
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
return d return d
async def sts_transmute(params: ConnectionParams): async def sts_transmute(params: ConnectionParams):
if not params.sts is None and not params.tls: if not params.sts is None and params.tls is None:
now = time() now = time()
since = (now-params.sts.created) since = (now-params.sts.created)
if since <= params.sts.duration: if since <= params.sts.duration:
params.port = params.sts.port params.port = params.sts.port
params.tls = True params.tls = TLSVerifyChain()
async def resume_transmute(params: ConnectionParams): async def resume_transmute(params: ConnectionParams):
if params.resume is not None: if params.resume is not None:
params.host = params.resume.address params.host = params.resume.address
@ -182,7 +183,7 @@ class CAPContext(ServerContext):
if not params.tls: if not params.tls:
if "port" in sts_dict: if "port" in sts_dict:
params.port = int(sts_dict["port"]) params.port = int(sts_dict["port"])
params.tls = True params.tls = TLSVerifyChain()
await self.server.bot.disconnect(self.server) await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params) await self.server.bot.add_server(self.server.name, params)

View File

@ -73,8 +73,7 @@ class Formatless(IMatchResponseParam):
def __init__(self, value: TYPE_MAYBELIT_VALUE): def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value) self._value = _assure_lit(value)
def __repr__(self) -> str: def __repr__(self) -> str:
brepr = super().__repr__() return f"Formatless({self._value!r})"
return f"Formatless({brepr})"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
strip = formatting.strip(arg) strip = formatting.strip(arg)
return self._value.match(server, strip) return self._value.match(server, strip)

View File

@ -1,6 +1,9 @@
from re import compile as re_compile
from typing import List, Optional from typing import List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .security import TLS, TLSNoVerify, TLSVerifyChain
class SASLParams(object): class SASLParams(object):
mechanism: str mechanism: str
@ -28,19 +31,24 @@ class ResumePolicy(object):
address: str address: str
token: str token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = {
"+": TLSVerifyChain,
"~": TLSNoVerify,
}
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
host: str host: str
port: int port: int
tls: bool tls: Optional[TLS] = field(default_factory=TLSVerifyChain)
username: Optional[str] = None username: Optional[str] = None
realname: Optional[str] = None realname: Optional[str] = None
bindhost: Optional[str] = None bindhost: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
tls_verify: bool = True
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None sts: Optional[STSPolicy] = None
@ -50,3 +58,26 @@ class ConnectionParams(object):
alt_nicknames: List[str] = field(default_factory=list) alt_nicknames: List[str] = field(default_factory=list)
autojoin: List[str] = field(default_factory=list) autojoin: List[str] = field(default_factory=list)
@staticmethod
def from_hoststring(
nickname: str,
hoststring: str
) -> "ConnectionParams":
ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:]
else:
host, _, port_s = hoststring.strip().partition(":")
tls_type: Optional[TLS] = None
if not port_s:
port_s = "6667"
else:
tls_type = _TLS_TYPES.get(port_s[0], lambda: None)()
if tls_type is not None:
port_s = port_s[1:] or "6697"
return ConnectionParams(nickname, host, int(port_s), tls_type)

View File

@ -32,7 +32,9 @@ AUTH_BYTE_MAX = 400
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY]) AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
NUMERICS_FAIL = Response(ERR_SASLFAIL) NUMERICS_FAIL = Response(ERR_SASLFAIL)
NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS]) NUMERICS_INITIAL = Responses([
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
])
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL]) NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
def _b64e(s: str): def _b64e(s: str):
@ -150,6 +152,8 @@ class SASLContext(ServerContext):
return SASLResult.SUCCESS return SASLResult.SUCCESS
elif line.command == "904": elif line.command == "904":
match.pop(0) match.pop(0)
else:
break
return SASLResult.FAILURE return SASLResult.FAILURE

View File

@ -1,13 +1,29 @@
import ssl import ssl
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class TLS:
client_keypair: Optional[Tuple[str, str]] = None
# tls without verification
class TLSNoVerify(TLS):
pass
# verify via CAs
class TLSVerifyChain(TLS):
pass
# verify by a pinned hash
class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str):
self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash):
pass
def tls_context(verify: bool=True) -> ssl.SSLContext: def tls_context(verify: bool=True) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx = ssl.create_default_context()
context.options |= ssl.OP_NO_SSLv2 if not verify:
context.options |= ssl.OP_NO_SSLv3 ctx.check_hostname = False
context.options |= ssl.OP_NO_TLSv1 ctx.verify_mode = ssl.CERT_NONE
context.load_default_certs() return ctx
if verify:
context.verify_mode = ssl.CERT_REQUIRED
return context

View File

@ -6,6 +6,7 @@ from collections import deque
from time import monotonic from time import monotonic
import anyio import anyio
from asyncio_rlock import RLock
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_ from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser from ircstates import Emit, Channel, ChannelUser
@ -54,8 +55,7 @@ class Server(IServer):
self.disconnected = False self.disconnected = False
self.throttle = Throttler( self.throttle = Throttler(rate_limit=100, period=1)
rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self.last_read = monotonic() self.last_read = monotonic()
@ -64,10 +64,14 @@ class Server(IServer):
self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Line] = deque() self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None self._ping_sent = False
self._wait_for_fut: Optional[Future[WaitFor]] = None self._read_lguard = RLock()
self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event()
self._pending_who: Deque[str] = deque() self._pending_who: Deque[str] = deque()
self._alt_nicks: List[str] = [] self._alt_nicks: List[str] = []
@ -120,9 +124,8 @@ class Server(IServer):
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host,
params.port, params.port,
tls =params.tls, tls =params.tls,
tls_verify=params.tls_verify, bindhost =params.bindhost)
bindhost =params.bindhost)
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
@ -177,17 +180,35 @@ class Server(IServer):
self._pending_who[0] == chan): self._pending_who[0] == chan):
self._pending_who.popleft() self._pending_who.popleft()
await self._next_who() await self._next_who()
elif (line.command in {
elif (line.command == ERR_NICKNAMEINUSE and ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
not self.registered): } and not self.registered):
if self._alt_nicks: if self._alt_nicks:
nick = self._alt_nicks.pop(0) nick = self._alt_nicks.pop(0)
await self.send(build("NICK", [nick])) await self.send(build("NICK", [nick]))
else: else:
await self.send(build("QUIT")) await self.send(build("QUIT"))
elif line.command in [RPL_ENDOFMOTD, ERR_NOMOTD]:
# we didn't get the nickname we wanted. watch for it if we can
if not self.nickname == self.params.nickname:
target = self.params.nickname
if self.isupport.monitor is not None:
await self.send(build("MONITOR", ["+", target]))
elif self.isupport.watch is not None:
await self.send(build("WATCH", [f"+{target}"]))
# has someone just stopped using the nickname we want?
elif line.command == RPL_LOGOFF:
await self._check_regain([line.params[1]])
elif line.command == RPL_MONOFFLINE:
await self._check_regain(line.params[1].split(","))
elif (line.command in ["NICK", "QUIT"] and
line.source is not None):
await self._check_regain([line.hostmask.nickname])
elif emit is not None: elif emit is not None:
if emit.command == "001": if emit.command == RPL_WELCOME:
await self.send(build("WHO", [self.nickname])) await self.send(build("WHO", [self.nickname]))
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME) self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
@ -218,6 +239,12 @@ class Server(IServer):
await self.line_read(line) await self.line_read(line)
async def _check_regain(self, nicks: List[str]):
for nick in nicks:
if (self.casefold_equals(nick, self.params.nickname) and
not self.nickname == self.params.nickname):
await self.send(build("NICK", [self.params.nickname]))
async def _batch_joins(self, async def _batch_joins(self,
channels: List[str], channels: List[str],
batch_n: int=10): batch_n: int=10):
@ -250,76 +277,43 @@ class Server(IServer):
self.last_read = monotonic() self.last_read = monotonic()
lines = self.recv(data) lines = self.recv(data)
for line in lines: for line in lines:
self.line_preread(line)
self._read_queue.append(line) self._read_queue.append(line)
async def _line_or_wait(self,
line_aw: asyncio.Task
) -> Optional[Tuple[Awaitable, WaitFor]]:
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
self._wait_for_fut = None
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (new_line_aw, wait_for_fut.result())
else:
return None
async def _read_lines(self): async def _read_lines(self):
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
sent_ping = False
while True: while True:
now = monotonic() async with self._read_lguard:
timeouts: List[float] = [] pass
timeouts.append((self.last_read+PING_TIMEOUT)-now)
if self._wait_for is not None:
_, wait_for = self._wait_for
timeouts.append(wait_for.deadline-now)
line = await self._read_line(max([0.1, min(timeouts)])) if not self._process_queue:
if line is None: async with self._read_lwork:
now = monotonic() read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT))
since = now-self.last_read wait_aw = asyncio.create_task(self._wait_for.wait())
dones, notdones = await asyncio.wait(
[read_aw, wait_aw],
return_when=asyncio.FIRST_COMPLETED
)
self._wait_for.clear()
if self._wait_for is not None: for done in dones:
aw, wait_for = self._wait_for if isinstance(done.result(), Line):
if wait_for.deadline <= now: self._ping_sent = False
self._wait_for = None line = done.result()
await aw emit = self.parse_tokens(line)
self._process_queue.append((line, emit))
elif done.result() is None:
if not self._ping_sent:
await self.send(build("PING", ["hello"]))
self._ping_sent = True
else:
await self.disconnect()
raise ServerDisconnectedException()
for notdone in notdones:
notdone.cancel()
if since >= PING_TIMEOUT:
if since >= (PING_TIMEOUT*2):
raise ServerDisconnectedException()
elif not sent_ping:
sent_ping = True
await self.send(build("PING", ["hello"]))
continue
else: else:
sent_ping = False line, emit = self._process_queue.popleft()
emit = self.parse_tokens(line) await self._on_read(line, emit)
waiting_lines.append((line, emit))
self.line_preread(line)
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.match(self, line):
wait_for.resolve(line)
self._wait_for = await self._line_or_wait(aw)
if self._wait_for is not None:
continue
else:
continue
for i in range(len(waiting_lines)):
line, emit = waiting_lines.pop(0)
line_aw = self._on_read(line, emit)
self._wait_for = await self._line_or_wait(line_aw)
if self._wait_for is not None:
break
async def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]], response: Union[IMatchResponse, Set[IMatchResponse]],
@ -333,22 +327,18 @@ class Server(IServer):
else: else:
response_obj = response response_obj = response
deadline = monotonic()+timeout async with self._read_lguard:
our_wait_for = WaitFor(response_obj, deadline) self._wait_for.set()
if self._wait_for_fut is not None: async with self._read_lwork:
self._wait_for_fut.set_result(our_wait_for) async with timeout_(timeout):
else: while True:
cur_task = asyncio.current_task() line = await self._read_line(timeout)
if cur_task is not None: if line:
self._wait_for = (cur_task, our_wait_for) self._ping_sent = False
emit = self.parse_tokens(line)
if sent_aw is not None: self._process_queue.append((line, emit))
sent_line = await sent_aw if response_obj.match(self, line):
label = str(sent_line.id) return line
our_wait_for.with_label(label)
async with timeout_(timeout):
return (await our_wait_for)
async def _on_send_line(self, line: Line): async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
@ -564,7 +554,7 @@ class Server(IServer):
for symbol in symbols: for symbol in symbols:
mode = self.isupport.prefix.from_prefix(symbol) mode = self.isupport.prefix.from_prefix(symbol)
if mode is not None: if mode is not None:
channel_user.modes.append(mode) channel_user.modes.add(mode)
obj.channels.append(channel_user) obj.channels.append(channel_user)
elif line.command == RPL_ENDOFWHOIS: elif line.command == RPL_ENDOFWHOIS:

View File

@ -1,10 +1,12 @@
from hashlib import sha512
from ssl import SSLContext from ssl import SSLContext
from typing import Optional, Tuple from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import tls_context from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
TLSVerifySHA512)
class TCPReader(ITCPReader): class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader): def __init__(self, reader: StreamReader):
@ -32,16 +34,18 @@ class TCPWriter(ITCPWriter):
class TCPTransport(ITCPTransport): class TCPTransport(ITCPTransport):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: bool, tls: Optional[TLS],
tls_verify: bool=True, bindhost: Optional[str]=None
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None cur_ssl: Optional[SSLContext] = None
if tls: if tls is not None:
cur_ssl = tls_context(tls_verify) cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
if tls.client_keypair is not None:
(client_cert, client_key) = tls.client_keypair
cur_ssl.load_cert_chain(client_cert, keyfile=client_key)
local_addr: Optional[Tuple[str, int]] = None local_addr: Optional[Tuple[str, int]] = None
if not bindhost is None: if not bindhost is None:
@ -55,5 +59,20 @@ class TCPTransport(ITCPTransport):
server_hostname=server_hostname, server_hostname=server_hostname,
ssl =cur_ssl, ssl =cur_ssl,
local_addr =local_addr) local_addr =local_addr)
if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info(
"ssl_object"
).getpeercert(True)
if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest()
else:
raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum:
raise ValueError(
f"pinned hash for {hostname} does not match ({sum})"
)
return (TCPReader(reader), TCPWriter(writer)) return (TCPReader(reader), TCPWriter(writer))

View File

@ -1,6 +1,6 @@
anyio ~=2.0.2 anyio ~=2.0.2
asyncio-rlock ~=0.1.0
asyncio-throttle ~=1.0.1 asyncio-throttle ~=1.0.1
dataclasses ~=0.6; python_version<"3.7" ircstates ~=0.12.1
ircstates ~=0.11.6
async_stagger ~=0.3.0 async_stagger ~=0.3.0
async_timeout ~=3.0.1 async_timeout ~=4.0.2

View File

@ -26,6 +26,6 @@ setup(
"Operating System :: Microsoft :: Windows", "Operating System :: Microsoft :: Windows",
"Topic :: Communications :: Chat :: Internet Relay Chat" "Topic :: Communications :: Chat :: Internet Relay Chat"
], ],
python_requires='>=3.6', python_requires='>=3.7',
install_requires=install_requires install_requires=install_requires
) )