mirror of https://github.com/jesopo/ircrobots
Compare commits
67 Commits
Author | SHA1 | Date |
---|---|---|
jesopo | 7c9a144124 | |
jesopo | e3c91a50e1 | |
jesopo | f2ba48a582 | |
jesopo | cf2e69a9e2 | |
jesopo | a1a459c13e | |
jesopo | 81fa77cf29 | |
jesopo | 422a9a93c1 | |
jesopo | b04a0e0136 | |
jesopo | 7bb4c3d069 | |
jesopo | 9a2f2156fe | |
alicetries | 0435404ec3 | |
jesopo | 63025af311 | |
jesopo | 20c4f8f98c | |
jesopo | 0ce3b9b0b0 | |
jesopo | 5b347f95c9 | |
jesopo | 0a5c774965 | |
jesopo | 8245a411c0 | |
jesopo | 80b941fa53 | |
jesopo | b7019d35c1 | |
jesopo | 66358f77e3 | |
jesopo | fcd2f5b1b2 | |
jesopo | 3e18deef86 | |
jesopo | 9ba5b2b90f | |
jesopo | 025fde97ee | |
jesopo | 05750f00d9 | |
jesopo | ac4c144d58 | |
jesopo | 6c91ebc7ec | |
jesopo | 0edcbfa234 | |
jesopo | 7b6a845927 | |
jesopo | dfd78b3d3e | |
jesopo | ab65e39ab9 | |
jesopo | 9ca1ec21c9 | |
jesopo | a03f11449c | |
jesopo | bb87c86b37 | |
jesopo | 8ee692f1be | |
jesopo | c7604686a2 | |
jesopo | 64935c7a8d | |
jesopo | fb93d59c43 | |
jesopo | ab17645d83 | |
jesopo | 8d3681eba1 | |
jesopo | 930342d74f | |
jesopo | dd41b0dbde | |
jesopo | f22471993a | |
jesopo | 6fddfb7fe9 | |
jesopo | b4eaf6c24c | |
jesopo | bdfb91b51d | |
jesopo | a14c7c34a2 | |
jesopo | 3574868458 | |
jesopo | 0253aba99e | |
jesopo | bfb5b4ec61 | |
jesopo | 6a05370a12 | |
jesopo | 90fb4b7bba | |
jesopo | d0c6b4a43d | |
jesopo | fc0e8470cc | |
jesopo | d0e0314169 | |
jesopo | a15e2bd1fb | |
jesopo | 7a59ece687 | |
jesopo | e7779bcf17 | |
jesopo | 04b44e2e94 | |
jesopo | 69e303dfa9 | |
jesopo | def58730bc | |
jesopo | 4f5fd90ca5 | |
jesopo | efc280b2e9 | |
jesopo | 834ca4b817 | |
jesopo | 48b0748b92 | |
jesopo | bd4758e97c | |
jesopo | 805b247375 |
|
@ -3,7 +3,7 @@ cache: pip
|
|||
python:
|
||||
- "3.7"
|
||||
- "3.8"
|
||||
- "3.8-dev"
|
||||
- "3.9"
|
||||
install:
|
||||
- pip3 install mypy -r requirements.txt
|
||||
script:
|
||||
|
|
|
@ -11,4 +11,4 @@ see [examples/](examples/) for some usage demonstration.
|
|||
|
||||
## contact
|
||||
|
||||
Come say hi at [##irctokens on freenode](https://webchat.freenode.net/?channels=%23%23irctokens)
|
||||
Come say hi at `#irctokens` on irc.libera.chat
|
||||
|
|
|
@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
|
|||
params = ConnectionParams(
|
||||
nickname,
|
||||
hostname,
|
||||
6697,
|
||||
tls=True)
|
||||
6697
|
||||
)
|
||||
await bot.add_server("freenode", params)
|
||||
await bot.run()
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ async def main():
|
|||
"MyNickname",
|
||||
host = "chat.freenode.invalid",
|
||||
port = 6697,
|
||||
tls = True,
|
||||
sasl = sasl_params)
|
||||
|
||||
await bot.add_server("freenode", params)
|
||||
|
|
|
@ -25,7 +25,7 @@ class Bot(BaseBot):
|
|||
async def main():
|
||||
bot = Bot()
|
||||
for name, host in SERVERS:
|
||||
params = ConnectionParams("BitBotNewTest", host, 6697, True)
|
||||
params = ConnectionParams("BitBotNewTest", host, 6697)
|
||||
await bot.add_server(name, params)
|
||||
|
||||
await bot.run()
|
||||
|
|
|
@ -3,3 +3,4 @@ from .server import Server
|
|||
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
||||
STSPolicy, ResumePolicy)
|
||||
from .ircv3 import Capability
|
||||
from .security import TLS
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import asyncio
|
||||
import asyncio, traceback
|
||||
import anyio
|
||||
from typing import Dict
|
||||
|
||||
|
@ -6,32 +6,45 @@ from ircstates.server import ServerDisconnectedException
|
|||
|
||||
from .server import ConnectionParams, Server
|
||||
from .transport import TCPTransport
|
||||
from .interface import IBot, IServer
|
||||
from .interface import IBot, IServer, ITCPTransport
|
||||
|
||||
class Bot(IBot):
|
||||
def __init__(self):
|
||||
self.servers: Dict[str, Server] = {}
|
||||
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
|
||||
|
||||
# methods designed to be overridden
|
||||
def create_server(self, name: str):
|
||||
return Server(self, name)
|
||||
|
||||
async def disconnected(self, server: IServer):
|
||||
if (server.name in self.servers and
|
||||
server.params is not None and
|
||||
server.disconnected):
|
||||
await asyncio.sleep(server.params.reconnect)
|
||||
await self.add_server(server.name, server.params)
|
||||
# /methods designed to be overridden
|
||||
|
||||
reconnect = server.params.reconnect
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(reconnect)
|
||||
try:
|
||||
await self.add_server(server.name, server.params)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# let's try again, exponential backoff up to 5 mins
|
||||
reconnect = min(reconnect*2, 300)
|
||||
else:
|
||||
break
|
||||
|
||||
async def disconnect(self, server: IServer):
|
||||
await server.disconnect()
|
||||
del self.servers[server.name]
|
||||
await server.disconnect()
|
||||
|
||||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||
async def add_server(self,
|
||||
name: str,
|
||||
params: ConnectionParams,
|
||||
transport: ITCPTransport = TCPTransport()) -> Server:
|
||||
server = self.create_server(name)
|
||||
self.servers[name] = server
|
||||
await server.connect(TCPTransport(), params)
|
||||
await server.connect(transport, params)
|
||||
await self._server_queue.put(server)
|
||||
return server
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from ircstates import Server, Emit
|
|||
from irctokens import Line, Hostmask
|
||||
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||
from .security import TLS
|
||||
|
||||
class ITCPReader(object):
|
||||
async def read(self, byte_count: int):
|
||||
|
@ -24,11 +25,10 @@ class ITCPWriter(object):
|
|||
|
||||
class ITCPTransport(object):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: bool,
|
||||
tls_verify: bool=True,
|
||||
bindhost: Optional[str]=None
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: Optional[TLS],
|
||||
bindhost: Optional[str]=None
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
pass
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from .contexts import ServerContext
|
|||
from .matching import Response, ANY
|
||||
from .interface import ICapability
|
||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||
from .security import TLSVerifyChain
|
||||
|
||||
class Capability(ICapability):
|
||||
def __init__(self,
|
||||
|
@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
|
|||
return d
|
||||
|
||||
async def sts_transmute(params: ConnectionParams):
|
||||
if not params.sts is None and not params.tls:
|
||||
if not params.sts is None and params.tls is None:
|
||||
now = time()
|
||||
since = (now-params.sts.created)
|
||||
if since <= params.sts.duration:
|
||||
params.port = params.sts.port
|
||||
params.tls = True
|
||||
params.tls = TLSVerifyChain()
|
||||
async def resume_transmute(params: ConnectionParams):
|
||||
if params.resume is not None:
|
||||
params.host = params.resume.address
|
||||
|
@ -182,7 +183,7 @@ class CAPContext(ServerContext):
|
|||
if not params.tls:
|
||||
if "port" in sts_dict:
|
||||
params.port = int(sts_dict["port"])
|
||||
params.tls = True
|
||||
params.tls = TLSVerifyChain()
|
||||
|
||||
await self.server.bot.disconnect(self.server)
|
||||
await self.server.bot.add_server(self.server.name, params)
|
||||
|
|
|
@ -73,8 +73,7 @@ class Formatless(IMatchResponseParam):
|
|||
def __init__(self, value: TYPE_MAYBELIT_VALUE):
|
||||
self._value = _assure_lit(value)
|
||||
def __repr__(self) -> str:
|
||||
brepr = super().__repr__()
|
||||
return f"Formatless({brepr})"
|
||||
return f"Formatless({self._value!r})"
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
strip = formatting.strip(arg)
|
||||
return self._value.match(server, strip)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from re import compile as re_compile
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .security import TLS, TLSNoVerify, TLSVerifyChain
|
||||
|
||||
class SASLParams(object):
|
||||
mechanism: str
|
||||
|
||||
|
@ -28,19 +31,24 @@ class ResumePolicy(object):
|
|||
address: str
|
||||
token: str
|
||||
|
||||
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
|
||||
|
||||
_TLS_TYPES = {
|
||||
"+": TLSVerifyChain,
|
||||
"~": TLSNoVerify,
|
||||
}
|
||||
@dataclass
|
||||
class ConnectionParams(object):
|
||||
nickname: str
|
||||
host: str
|
||||
port: int
|
||||
tls: bool
|
||||
tls: Optional[TLS] = field(default_factory=TLSVerifyChain)
|
||||
|
||||
username: Optional[str] = None
|
||||
realname: Optional[str] = None
|
||||
bindhost: Optional[str] = None
|
||||
|
||||
password: Optional[str] = None
|
||||
tls_verify: bool = True
|
||||
sasl: Optional[SASLParams] = None
|
||||
|
||||
sts: Optional[STSPolicy] = None
|
||||
|
@ -50,3 +58,26 @@ class ConnectionParams(object):
|
|||
alt_nicknames: List[str] = field(default_factory=list)
|
||||
|
||||
autojoin: List[str] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_hoststring(
|
||||
nickname: str,
|
||||
hoststring: str
|
||||
) -> "ConnectionParams":
|
||||
|
||||
ipv6host = RE_IPV6HOST.search(hoststring)
|
||||
if ipv6host is not None and ipv6host.start() == 0:
|
||||
host = ipv6host.group(1)
|
||||
port_s = hoststring[ipv6host.end()+1:]
|
||||
else:
|
||||
host, _, port_s = hoststring.strip().partition(":")
|
||||
|
||||
tls_type: Optional[TLS] = None
|
||||
if not port_s:
|
||||
port_s = "6667"
|
||||
else:
|
||||
tls_type = _TLS_TYPES.get(port_s[0], lambda: None)()
|
||||
if tls_type is not None:
|
||||
port_s = port_s[1:] or "6697"
|
||||
|
||||
return ConnectionParams(nickname, host, int(port_s), tls_type)
|
||||
|
|
|
@ -32,7 +32,9 @@ AUTH_BYTE_MAX = 400
|
|||
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
|
||||
|
||||
NUMERICS_FAIL = Response(ERR_SASLFAIL)
|
||||
NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS])
|
||||
NUMERICS_INITIAL = Responses([
|
||||
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
|
||||
])
|
||||
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
|
||||
|
||||
def _b64e(s: str):
|
||||
|
@ -150,6 +152,8 @@ class SASLContext(ServerContext):
|
|||
return SASLResult.SUCCESS
|
||||
elif line.command == "904":
|
||||
match.pop(0)
|
||||
else:
|
||||
break
|
||||
|
||||
return SASLResult.FAILURE
|
||||
|
||||
|
|
|
@ -1,13 +1,29 @@
|
|||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@dataclass
|
||||
class TLS:
|
||||
client_keypair: Optional[Tuple[str, str]] = None
|
||||
|
||||
# tls without verification
|
||||
class TLSNoVerify(TLS):
|
||||
pass
|
||||
|
||||
# verify via CAs
|
||||
class TLSVerifyChain(TLS):
|
||||
pass
|
||||
|
||||
# verify by a pinned hash
|
||||
class TLSVerifyHash(TLSNoVerify):
|
||||
def __init__(self, sum: str):
|
||||
self.sum = sum.lower()
|
||||
class TLSVerifySHA512(TLSVerifyHash):
|
||||
pass
|
||||
|
||||
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_TLSv1
|
||||
context.load_default_certs()
|
||||
|
||||
if verify:
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
|
||||
return context
|
||||
ctx = ssl.create_default_context()
|
||||
if not verify:
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
|
|
|
@ -6,6 +6,7 @@ from collections import deque
|
|||
from time import monotonic
|
||||
|
||||
import anyio
|
||||
from asyncio_rlock import RLock
|
||||
from asyncio_throttle import Throttler
|
||||
from async_timeout import timeout as timeout_
|
||||
from ircstates import Emit, Channel, ChannelUser
|
||||
|
@ -54,8 +55,7 @@ class Server(IServer):
|
|||
|
||||
self.disconnected = False
|
||||
|
||||
self.throttle = Throttler(
|
||||
rate_limit=100, period=THROTTLE_TIME)
|
||||
self.throttle = Throttler(rate_limit=100, period=1)
|
||||
|
||||
self.sasl_state = SASLResult.NONE
|
||||
self.last_read = monotonic()
|
||||
|
@ -64,13 +64,17 @@ class Server(IServer):
|
|||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||
|
||||
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
|
||||
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||
self._ping_sent = False
|
||||
self._read_lguard = RLock()
|
||||
self.read_lock = self._read_lguard
|
||||
self._read_lwork = asyncio.Lock()
|
||||
self._wait_for = asyncio.Event()
|
||||
|
||||
self._pending_who: Deque[str] = deque()
|
||||
self._initial_nick: Optional[str] = None
|
||||
self._pending_who: Deque[str] = deque()
|
||||
self._alt_nicks: List[str] = []
|
||||
|
||||
def hostmask(self) -> str:
|
||||
hostmask = self.nickname
|
||||
|
@ -120,9 +124,8 @@ class Server(IServer):
|
|||
reader, writer = await transport.connect(
|
||||
params.host,
|
||||
params.port,
|
||||
tls =params.tls,
|
||||
tls_verify=params.tls_verify,
|
||||
bindhost =params.bindhost)
|
||||
tls =params.tls,
|
||||
bindhost =params.bindhost)
|
||||
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
@ -140,7 +143,10 @@ class Server(IServer):
|
|||
username = self.params.username or nickname
|
||||
realname = self.params.realname or nickname
|
||||
|
||||
self._initial_nick = nickname
|
||||
alt_nicks = self.params.alt_nicknames
|
||||
if not alt_nicks:
|
||||
alt_nicks = [nickname+"_"*i for i in range(1, 4)]
|
||||
self._alt_nicks = alt_nicks
|
||||
|
||||
# these must remain non-awaited; reading hasn't started yet
|
||||
if not self.params.password is None:
|
||||
|
@ -174,24 +180,35 @@ class Server(IServer):
|
|||
self._pending_who[0] == chan):
|
||||
self._pending_who.popleft()
|
||||
await self._next_who()
|
||||
|
||||
elif (line.command == ERR_NICKNAMEINUSE and
|
||||
self._initial_nick is not None):
|
||||
nick = self._initial_nick
|
||||
|
||||
alt_nicks = self.params.alt_nicknames
|
||||
if not alt_nicks:
|
||||
alt_nicks = [nick+"_"*i for i in range(3)]
|
||||
|
||||
for alt_nick in alt_nicks:
|
||||
if await self.send_nick(alt_nick):
|
||||
break
|
||||
elif (line.command in {
|
||||
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
|
||||
} and not self.registered):
|
||||
if self._alt_nicks:
|
||||
nick = self._alt_nicks.pop(0)
|
||||
await self.send(build("NICK", [nick]))
|
||||
else:
|
||||
self._initial_nick = None
|
||||
await self.send(build("QUIT"))
|
||||
|
||||
elif line.command in [RPL_ENDOFMOTD, ERR_NOMOTD]:
|
||||
# we didn't get the nickname we wanted. watch for it if we can
|
||||
if not self.nickname == self.params.nickname:
|
||||
target = self.params.nickname
|
||||
if self.isupport.monitor is not None:
|
||||
await self.send(build("MONITOR", ["+", target]))
|
||||
elif self.isupport.watch is not None:
|
||||
await self.send(build("WATCH", [f"+{target}"]))
|
||||
|
||||
# has someone just stopped using the nickname we want?
|
||||
elif line.command == RPL_LOGOFF:
|
||||
await self._check_regain([line.params[1]])
|
||||
elif line.command == RPL_MONOFFLINE:
|
||||
await self._check_regain(line.params[1].split(","))
|
||||
elif (line.command in ["NICK", "QUIT"] and
|
||||
line.source is not None):
|
||||
await self._check_regain([line.hostmask.nickname])
|
||||
|
||||
elif emit is not None:
|
||||
if emit.command == "001":
|
||||
if emit.command == RPL_WELCOME:
|
||||
await self.send(build("WHO", [self.nickname]))
|
||||
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
||||
|
||||
|
@ -222,6 +239,12 @@ class Server(IServer):
|
|||
|
||||
await self.line_read(line)
|
||||
|
||||
async def _check_regain(self, nicks: List[str]):
|
||||
for nick in nicks:
|
||||
if (self.casefold_equals(nick, self.params.nickname) and
|
||||
not self.nickname == self.params.nickname):
|
||||
await self.send(build("NICK", [self.params.nickname]))
|
||||
|
||||
async def _batch_joins(self,
|
||||
channels: List[str],
|
||||
batch_n: int=10):
|
||||
|
@ -254,76 +277,43 @@ class Server(IServer):
|
|||
self.last_read = monotonic()
|
||||
lines = self.recv(data)
|
||||
for line in lines:
|
||||
self.line_preread(line)
|
||||
self._read_queue.append(line)
|
||||
|
||||
async def _line_or_wait(self,
|
||||
line_aw: asyncio.Task
|
||||
) -> Optional[Tuple[Awaitable, WaitFor]]:
|
||||
wait_for_fut: Future[WaitFor] = Future()
|
||||
self._wait_for_fut = wait_for_fut
|
||||
|
||||
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
self._wait_for_fut = None
|
||||
|
||||
if wait_for_fut.done():
|
||||
new_line_aw = list(pend)[0]
|
||||
return (new_line_aw, wait_for_fut.result())
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _read_lines(self):
|
||||
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
|
||||
sent_ping = False
|
||||
while True:
|
||||
now = monotonic()
|
||||
timeouts: List[float] = []
|
||||
timeouts.append((self.last_read+PING_TIMEOUT)-now)
|
||||
if self._wait_for is not None:
|
||||
_, wait_for = self._wait_for
|
||||
timeouts.append(wait_for.deadline-now)
|
||||
async with self._read_lguard:
|
||||
pass
|
||||
|
||||
line = await self._read_line(max([0.1, min(timeouts)]))
|
||||
if line is None:
|
||||
now = monotonic()
|
||||
since = now-self.last_read
|
||||
if not self._process_queue:
|
||||
async with self._read_lwork:
|
||||
read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT))
|
||||
wait_aw = asyncio.create_task(self._wait_for.wait())
|
||||
dones, notdones = await asyncio.wait(
|
||||
[read_aw, wait_aw],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
self._wait_for.clear()
|
||||
|
||||
if self._wait_for is not None:
|
||||
aw, wait_for = self._wait_for
|
||||
if wait_for.deadline <= now:
|
||||
self._wait_for = None
|
||||
await aw
|
||||
for done in dones:
|
||||
if isinstance(done.result(), Line):
|
||||
self._ping_sent = False
|
||||
line = done.result()
|
||||
emit = self.parse_tokens(line)
|
||||
self._process_queue.append((line, emit))
|
||||
elif done.result() is None:
|
||||
if not self._ping_sent:
|
||||
await self.send(build("PING", ["hello"]))
|
||||
self._ping_sent = True
|
||||
else:
|
||||
await self.disconnect()
|
||||
raise ServerDisconnectedException()
|
||||
for notdone in notdones:
|
||||
notdone.cancel()
|
||||
|
||||
if since >= PING_TIMEOUT:
|
||||
if since >= (PING_TIMEOUT*2):
|
||||
raise ServerDisconnectedException()
|
||||
elif not sent_ping:
|
||||
sent_ping = True
|
||||
await self.send(build("PING", ["hello"]))
|
||||
continue
|
||||
else:
|
||||
sent_ping = False
|
||||
emit = self.parse_tokens(line)
|
||||
|
||||
waiting_lines.append((line, emit))
|
||||
self.line_preread(line)
|
||||
|
||||
if self._wait_for is not None:
|
||||
aw, wait_for = self._wait_for
|
||||
if wait_for.match(self, line):
|
||||
wait_for.resolve(line)
|
||||
self._wait_for = await self._line_or_wait(aw)
|
||||
if self._wait_for is not None:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
for i in range(len(waiting_lines)):
|
||||
line, emit = waiting_lines.pop(0)
|
||||
line_aw = self._on_read(line, emit)
|
||||
self._wait_for = await self._line_or_wait(line_aw)
|
||||
if self._wait_for is not None:
|
||||
break
|
||||
line, emit = self._process_queue.popleft()
|
||||
await self._on_read(line, emit)
|
||||
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
|
@ -337,22 +327,18 @@ class Server(IServer):
|
|||
else:
|
||||
response_obj = response
|
||||
|
||||
deadline = monotonic()+timeout
|
||||
our_wait_for = WaitFor(response_obj, deadline)
|
||||
if self._wait_for_fut is not None:
|
||||
self._wait_for_fut.set_result(our_wait_for)
|
||||
else:
|
||||
cur_task = asyncio.current_task()
|
||||
if cur_task is not None:
|
||||
self._wait_for = (cur_task, our_wait_for)
|
||||
|
||||
if sent_aw is not None:
|
||||
sent_line = await sent_aw
|
||||
label = str(sent_line.id)
|
||||
our_wait_for.with_label(label)
|
||||
|
||||
async with timeout_(timeout):
|
||||
return (await our_wait_for)
|
||||
async with self._read_lguard:
|
||||
self._wait_for.set()
|
||||
async with self._read_lwork:
|
||||
async with timeout_(timeout):
|
||||
while True:
|
||||
line = await self._read_line(timeout)
|
||||
if line:
|
||||
self._ping_sent = False
|
||||
emit = self.parse_tokens(line)
|
||||
self._process_queue.append((line, emit))
|
||||
if response_obj.match(self, line):
|
||||
return line
|
||||
|
||||
async def _on_send_line(self, line: Line):
|
||||
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
||||
|
@ -568,7 +554,7 @@ class Server(IServer):
|
|||
for symbol in symbols:
|
||||
mode = self.isupport.prefix.from_prefix(symbol)
|
||||
if mode is not None:
|
||||
channel_user.modes.append(mode)
|
||||
channel_user.modes.add(mode)
|
||||
|
||||
obj.channels.append(channel_user)
|
||||
elif line.command == RPL_ENDOFWHOIS:
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from hashlib import sha512
|
||||
from ssl import SSLContext
|
||||
from typing import Optional, Tuple
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from async_stagger import open_connection
|
||||
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
from .security import tls_context
|
||||
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
|
||||
TLSVerifySHA512)
|
||||
|
||||
class TCPReader(ITCPReader):
|
||||
def __init__(self, reader: StreamReader):
|
||||
|
@ -32,16 +34,18 @@ class TCPWriter(ITCPWriter):
|
|||
|
||||
class TCPTransport(ITCPTransport):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: bool,
|
||||
tls_verify: bool=True,
|
||||
bindhost: Optional[str]=None
|
||||
hostname: str,
|
||||
port: int,
|
||||
tls: Optional[TLS],
|
||||
bindhost: Optional[str]=None
|
||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||
|
||||
cur_ssl: Optional[SSLContext] = None
|
||||
if tls:
|
||||
cur_ssl = tls_context(tls_verify)
|
||||
if tls is not None:
|
||||
cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
|
||||
if tls.client_keypair is not None:
|
||||
(client_cert, client_key) = tls.client_keypair
|
||||
cur_ssl.load_cert_chain(client_cert, keyfile=client_key)
|
||||
|
||||
local_addr: Optional[Tuple[str, int]] = None
|
||||
if not bindhost is None:
|
||||
|
@ -55,5 +59,20 @@ class TCPTransport(ITCPTransport):
|
|||
server_hostname=server_hostname,
|
||||
ssl =cur_ssl,
|
||||
local_addr =local_addr)
|
||||
|
||||
if isinstance(tls, TLSVerifyHash):
|
||||
cert: bytes = writer.transport.get_extra_info(
|
||||
"ssl_object"
|
||||
).getpeercert(True)
|
||||
if isinstance(tls, TLSVerifySHA512):
|
||||
sum = sha512(cert).hexdigest()
|
||||
else:
|
||||
raise ValueError(f"unknown hash pinning {type(tls)}")
|
||||
|
||||
if not sum == tls.sum:
|
||||
raise ValueError(
|
||||
f"pinned hash for {hostname} does not match ({sum})"
|
||||
)
|
||||
|
||||
return (TCPReader(reader), TCPWriter(writer))
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
anyio ==1.3.0
|
||||
asyncio-throttle ==1.0.1
|
||||
dataclasses ==0.6
|
||||
ircstates ==0.11.2
|
||||
async_stagger ==0.3.0
|
||||
async_timeout ==3.0.1
|
||||
anyio ~=2.0.2
|
||||
asyncio-rlock ~=0.1.0
|
||||
asyncio-throttle ~=1.0.1
|
||||
ircstates ~=0.12.1
|
||||
async_stagger ~=0.3.0
|
||||
async_timeout ~=4.0.2
|
||||
|
|
Loading…
Reference in New Issue