ircrobots/ircrobots/server.py

563 lines
20 KiB
Python

import asyncio
from asyncio import Future, PriorityQueue
from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List,
Optional, Set, Tuple, Union)
from collections import deque
from time import monotonic
import anyio
from asyncio_rlock import RLock
from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser
from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException
from ircstates.names import Name
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
from .sasl import SASLContext, SASLResult
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
Folded)
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter
# Steady-state send throttle; Server._on_read() applies these through
# set_throttle() once RPL_WELCOME has been seen.
THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds
# Read timeout the read loop hands to _read_line(). One expiry makes the loop
# send a PING; a second consecutive expiry makes it disconnect.
PING_TIMEOUT = 60 # seconds
# Default `timeout` of Server.wait_for().
WAIT_TIMEOUT = 20 # seconds
# Numerics that send_joins() accepts as a reply to JOIN and for which it reads
# the channel name from params[1].
JOIN_ERR_FIRST = [
ERR_NOSUCHCHANNEL,
ERR_BADCHANNAME,
ERR_UNAVAILRESOURCE,
ERR_TOOMANYCHANNELS,
ERR_BANNEDFROMCHAN,
ERR_INVITEONLYCHAN,
ERR_BADCHANNELKEY,
ERR_NEEDREGGEDNICK,
ERR_THROTTLE
]
# One IRC server connection owned by a bot: queues outgoing lines, reads and
# parses incoming ones, and offers send_*/wait_for helpers.
class Server(IServer):
# stream reader/writer pair; only assigned by connect()
_reader: ITCPReader
_writer: ITCPWriter
# connection parameters; only assigned by connect()
params: ConnectionParams
def __init__(self, bot: IBot, name: str):
    """Initialise the bookkeeping for the server called *name*, owned by *bot*."""
    super().__init__(name)

    # identity / session state
    self.bot = bot
    self.disconnected = False
    self.sasl_state = SASLResult.NONE
    self.desired_caps: Set[ICapability] = set()
    self._alt_nicks: List[str] = []

    # outgoing side: permissive throttle until set_throttle() tightens it
    self.throttle = Throttler(rate_limit=100, period=1)
    self._sent_count: int = 0
    self._send_queue: PriorityQueue[SentLine] = PriorityQueue()

    # incoming side
    self.last_read = monotonic()
    self._ping_sent = False
    self._read_queue: Deque[Line] = deque()
    self._process_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
    self._pending_who: Deque[str] = deque()

    # coordination between the read loop and wait_for() callers;
    # read_lock is a public alias of the guard lock
    guard = RLock()
    self._read_lguard = guard
    self.read_lock = guard
    self._read_lwork = asyncio.Lock()
    self._wait_for = asyncio.Event()
def hostmask(self) -> str:
    """Return our own mask as ``nick[!user][@host]``, skipping unknown parts."""
    mask = self.nickname
    for separator, part in (("!", self.username), ("@", self.hostname)):
        if part is not None:
            mask += f"{separator}{part}"
    return mask
def send_raw(self, line: str, priority=SendPriority.DEFAULT
        ) -> Awaitable[SentLine]:
    """Tokenise the raw text *line* and queue it through send()."""
    parsed = tokenise(line)
    return self.send(parsed, priority)
def send(self,
        line: Line,
        priority=SendPriority.DEFAULT
        ) -> Awaitable[SentLine]:
    """Queue *line* for sending and return the future of its SentLine.

    The line_presend() hook runs first. When the label capability is
    available, the line is tagged with the SentLine id unless it already
    carries that tag.
    """
    self.line_presend(line)

    sent_line = SentLine(self._sent_count, priority, line)
    self._sent_count += 1

    label = self.cap_available(CAP_LABEL)
    if label is not None:
        tag = LABEL_TAG_MAP[label]
        if line.tags is None:
            line.tags = {tag: str(sent_line.id)}
        elif tag not in line.tags:
            line.tags[tag] = str(sent_line.id)

    self._send_queue.put_nowait(sent_line)
    return sent_line.future
def set_throttle(self, rate: int, time: float):
    """Set the send throttle's ``rate_limit`` to *rate* and ``period`` to *time*."""
    throttler = self.throttle
    throttler.rate_limit = rate
    throttler.period = time
def server_address(self) -> Tuple[str, int]:
    """Return the peer address reported by the connection's writer."""
    writer = self._writer
    return writer.get_peer()
async def connect(self,
        transport: ITCPTransport,
        params: ConnectionParams):
    """Open the connection described by *params* and queue the handshake.

    *params* is passed through sts_transmute() and resume_transmute() before
    its host/port/tls/bindhost are used to connect.
    """
    await sts_transmute(params)
    await resume_transmute(params)

    self._reader, self._writer = await transport.connect(
        params.host,
        params.port,
        tls=params.tls,
        bindhost=params.bindhost)
    self.params = params

    await self.handshake()
async def disconnect(self):
    """Close the connection, if one is open, and drop any unread lines.

    Calling this before connect() (when no writer has been assigned yet) or
    after a previous disconnect() is a no-op.
    """
    # `_writer` is only annotated on the class, so it does not exist as an
    # attribute until connect() has run; plain attribute access would raise.
    writer = getattr(self, "_writer", None)
    if writer is None:
        return

    try:
        await writer.close()
    finally:
        # forget the writer even if close() raised, so a later call does not
        # try to close the same half-dead connection again
        self._writer = None
        self._read_queue.clear()
async def handshake(self):
    """Queue the registration burst: optional PASS, then CAP LS, NICK and USER.

    Username and realname fall back to the nickname. When no alternative
    nicknames are configured, ``nick_``, ``nick__`` and ``nick___`` are used.
    """
    params = self.params
    nickname = params.nickname
    username = params.username or nickname
    realname = params.realname or nickname

    # Work on a private copy: _on_read() pops entries off _alt_nicks, and
    # popping from params.alt_nicknames itself would use up the caller's
    # configuration, leaving nothing for the next connection attempt.
    alt_nicks = list(params.alt_nicknames or [])
    if not alt_nicks:
        alt_nicks = [nickname+"_"*i for i in range(1, 4)]
    self._alt_nicks = alt_nicks

    # these must remain non-awaited; reading hasn't started yet
    if params.password is not None:
        self.send(build("PASS", [params.password]))
    self.send(build("CAP", ["LS", "302"]))
    self.send(build("NICK", [nickname]))
    self.send(build("USER", [username, "0", "*", realname]))
# to be overridden
def line_preread(self, line: Line):
    """Hook: _read_line() calls this for each received line just before queueing it. No-op by default."""
def line_presend(self, line: Line):
    """Hook: send() calls this first, before the line is queued. No-op by default."""
async def line_read(self, line: Line):
    """Hook: awaited at the end of _on_read() for every processed line. No-op by default."""
async def line_send(self, line: Line):
    """Hook: awaited by _send_lines() for each line after it has been written. No-op by default."""
async def sts_policy(self, sts: STSPolicy):
    """Hook for subclasses; receives an STSPolicy. No-op by default.

    NOTE(review): the caller is not in this file section - presumably the
    STS capability handling; confirm.
    """
async def resume_policy(self, resume: ResumePolicy):
    """Hook for subclasses; receives a ResumePolicy. No-op by default.

    NOTE(review): the caller is not in this file section - presumably the
    resume capability handling; confirm.
    """
# /to be overridden
# Built-in reaction to one received line. `emit` is the value parse_tokens()
# produced for that line (may be None). Exactly one branch of the chain below
# runs; afterwards the line_read() hook is awaited.
async def _on_read(self, line: Line, emit: Optional[Emit]):
# keepalive: answer PING with a PONG carrying the same params
if line.command == "PING":
await self.send(build("PONG", line.params))
# end of a WHO listing: if it is for the channel at the head of
# _pending_who, drop that entry and start the next queued WHO
elif line.command == RPL_ENDOFWHO:
chan = self.casefold(line.params[1])
if (self._pending_who and
self._pending_who[0] == chan):
self._pending_who.popleft()
await self._next_who()
# our nickname was rejected before registration: try the next alternative
# nickname, or QUIT when none are left
elif (line.command in {
ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
} and not self.registered):
if self._alt_nicks:
nick = self._alt_nicks.pop(0)
await self.send(build("NICK", [nick]))
else:
await self.send(build("QUIT"))
elif line.command in [RPL_ENDOFMOTD, ERR_NOMOTD]:
# we didn't get the nickname we wanted. watch for it if we can
if not self.nickname == self.params.nickname:
target = self.params.nickname
# MONITOR is preferred; WATCH is only used when MONITOR is unavailable
if self.isupport.monitor is not None:
await self.send(build("MONITOR", ["+", target]))
elif self.isupport.watch is not None:
await self.send(build("WATCH", [f"+{target}"]))
# has someone just stopped using the nickname we want?
elif line.command == RPL_LOGOFF:
await self._check_regain([line.params[1]])
elif line.command == RPL_MONOFFLINE:
# params[1] is split on "," into the individual entries to check
await self._check_regain(line.params[1].split(","))
elif (line.command in ["NICK", "QUIT"] and
line.source is not None):
await self._check_regain([line.hostmask.nickname])
# everything below acts on the parsed emit rather than the raw line
elif emit is not None:
# registration finished: WHO ourselves, switch to the steady-state
# throttle, then join the configured autojoin channels
if emit.command == RPL_WELCOME:
await self.send(build("WHO", [self.nickname]))
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
if self.params.autojoin:
await self._batch_joins(self.params.autojoin)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_ls(emit)
# only act on LS once the (possibly multi-line) listing is finished
elif (emit.subcommand == "LS" and
emit.finished):
# before registration run the CAP handshake; afterwards treat the
# listing the same way as CAP NEW
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
elif emit.command == "JOIN":
# we ourselves joined a channel: ask for its modes, then send
# MODE +<a_modes> (presumably to fetch list modes such as bans -
# TODO confirm), and queue a WHO for it
if emit.self and not emit.channel is None:
chan = emit.channel.name_lower
await self.send(build("MODE", [chan]))
modes = "".join(self.isupport.chanmodes.a_modes)
await self.send(build("MODE", [chan, f"+{modes}"]))
self._pending_who.append(chan)
# only kick off a WHO when none is already in flight
if len(self._pending_who) == 1:
await self._next_who()
await self.line_read(line)
async def _check_regain(self, nicks: List[str]):
for nick in nicks:
if (self.casefold_equals(nick, self.params.nickname) and
not self.nickname == self.params.nickname):
await self.send(build("NICK", [self.params.nickname]))
async def _batch_joins(self,
channels: List[str],
batch_n: int=10):
#TODO: do as many JOINs in one line as we can fit
#TODO: channel keys
for i in range(0, len(channels), batch_n):
batch = channels[i:i+batch_n]
await self.send(build("JOIN", [",".join(batch)]))
async def _next_who(self):
if self._pending_who:
chan = self._pending_who[0]
if self.isupport.whox:
await self.send(self.prepare_whox(chan))
else:
await self.send(build("WHO", [chan]))
async def _read_line(self, timeout: float) -> Optional[Line]:
    """Return the next queued line, reading from the socket as needed.

    Returns None when a single read does not complete within *timeout*
    seconds.
    """
    while not self._read_queue:
        try:
            async with timeout_(timeout):
                chunk = await self._reader.read(1024)
        except asyncio.TimeoutError:
            return None

        self.last_read = monotonic()
        # one chunk may yield zero, one or several complete lines
        for parsed in self.recv(chunk):
            self.line_preread(parsed)
            self._read_queue.append(parsed)

    return self._read_queue.popleft()
# Read loop (never returns normally): when _process_queue has entries, pop one
# and run _on_read() on it; otherwise read the next line from the connection,
# giving way to wait_for() callers.
# Raises ServerDisconnectedException after two consecutive read timeouts.
async def _read_lines(self):
while True:
# pass through the guard lock: waits here while a wait_for() caller
# holds it
async with self._read_lguard:
pass
if not self._process_queue:
async with self._read_lwork:
# race a read against the _wait_for event, which wait_for() sets
# when it wants to take over reading
read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT))
wait_aw = asyncio.create_task(self._wait_for.wait())
dones, notdones = await asyncio.wait(
[read_aw, wait_aw],
return_when=asyncio.FIRST_COMPLETED
)
self._wait_for.clear()
for done in dones:
# a Line result can only come from the read task
if isinstance(done.result(), Line):
self._ping_sent = False
line = done.result()
emit = self.parse_tokens(line)
self._process_queue.append((line, emit))
# None is _read_line()'s timeout result
elif done.result() is None:
# first timeout: probe the server with a PING;
# second timeout in a row: give up on the connection
if not self._ping_sent:
await self.send(build("PING", ["hello"]))
self._ping_sent = True
else:
await self.disconnect()
raise ServerDisconnectedException()
# cancel whichever task did not finish
for notdone in notdones:
notdone.cancel()
else:
line, emit = self._process_queue.popleft()
await self._on_read(line, emit)
# Read lines until one matches `response`, then return that line.
# A set of responses is combined with ResponseOr. Every line read here is also
# appended to _process_queue, so the read loop still runs _on_read() on it.
# `timeout` is used both for each individual read and for the whole wait;
# presumably asyncio.TimeoutError propagates when the overall timeout_ block
# expires - TODO confirm against async_timeout.
# NOTE(review): `sent_aw` is accepted but never used in this body.
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None,
timeout: float=WAIT_TIMEOUT
) -> Line:
response_obj: IMatchResponse
if isinstance(response, set):
response_obj = ResponseOr(*response)
else:
response_obj = response
# hold the guard so the read loop stops at the top of its next iteration
async with self._read_lguard:
# wake the read loop out of its pending read so it releases _read_lwork
self._wait_for.set()
async with self._read_lwork:
async with timeout_(timeout):
while True:
line = await self._read_line(timeout)
# _read_line() returns None on a read timeout; just try again
if line:
self._ping_sent = False
emit = self.parse_tokens(line)
self._process_queue.append((line, emit))
if response_obj.match(self, line):
return line
async def _on_send_line(self, line: Line):
    """When the echo capability has not been agreed, feed our own
    PRIVMSG/NOTICE/TAGMSG back into the read queue with our hostmask as source."""
    is_message = line.command in ["PRIVMSG", "NOTICE", "TAGMSG"]
    if not is_message or self.cap_agreed(CAP_ECHO):
        return

    echoed = line.with_source(self.hostmask())
    self._read_queue.append(echoed)
async def _send_lines(self):
while True:
lines: List[SentLine] = []
while (not lines or
(len(lines) < 5 and self._send_queue.qsize() > 0)):
prio_line = await self._send_queue.get()
lines.append(prio_line)
for line in lines:
async with self.throttle:
self._writer.write(
f"{line.line.format()}\r\n".encode("utf8"))
await self._writer.drain()
for line in lines:
await self._on_send_line(line.line)
await self.line_send(line.line)
line.future.set_result(line)
# CAP-related
def cap_agreed(self, capability: ICapability) -> bool:
    """Return True when cap_available() yields a truthy name for *capability*."""
    agreed_name = self.cap_available(capability)
    return bool(agreed_name)
def cap_available(self, capability: ICapability) -> Optional[str]:
    """Return what ``capability.available()`` reports for our agreed caps."""
    return capability.available(
        self.agreed_caps
    )
async def _cap_ls(self, emit: Emit):
    """Split the emit's ``name[=value]`` tokens into a dict and pass it to
    CAPContext.on_ls(); do nothing when the emit carries no tokens."""
    if emit.tokens is None:
        return

    pairs = (token.partition("=") for token in emit.tokens)
    tokens: Dict[str, str] = {key: value for key, _, value in pairs}
    await CAPContext(self).on_ls(tokens)
async def sasl_auth(self, params: SASLParams) -> bool:
    """Attempt SASL with *params* and store the outcome in ``sasl_state``.

    Returns False without trying when SASL was already attempted or the SASL
    capability is not agreed. Returns True whenever an attempt was made,
    whatever its result.
    """
    if self.sasl_state != SASLResult.NONE or not self.cap_agreed(CAP_SASL):
        return False

    self.sasl_state = await SASLContext(self).from_params(params)
    return True
# /CAP-related
def send_nick(self, new_nick: str) -> Awaitable[bool]:
    """Send NICK *new_nick*; the awaitable yields True when our own NICK
    change is seen, and False when one of the error numerics arrives."""
    fut = self.send(build("NICK", [new_nick]))

    async def _assure() -> bool:
        changed = Response("NICK", [Folded(new_nick)], source=MASK_SELF)
        refused = Responses([
            ERR_BANNICKCHANGE,
            ERR_NICKTOOFAST,
            ERR_CANTCHANGENICK
        ], [ANY])
        unusable = Responses([
            ERR_NICKNAMEINUSE,
            ERR_ERRONEUSNICKNAME,
            ERR_UNAVAILRESOURCE
        ], [ANY, Folded(new_nick)])

        line = await self.wait_for({changed, refused, unusable}, fut)
        return line.command == "NICK"

    return MaybeAwait(_assure)
def send_join(self,
        name: str,
        key: Optional[str]=None
        ) -> Awaitable[Channel]:
    """Join the single channel *name* (optionally with *key*) via
    send_joins() and yield its first resulting channel."""
    keys = [] if key is None else [key]
    fut = self.send_joins([name], keys)

    async def _assure():
        joined = await fut
        return joined[0]

    return MaybeAwait(_assure)
def send_part(self, name: str):
    """Send PART for *name*; the awaitable completes (with None) once our own
    PART for that channel has been read."""
    fut = self.send(build("PART", [name]))

    async def _assure():
        await self.wait_for(
            Response("PART", [Folded(name)], source=MASK_SELF),
            fut
        )

    return MaybeAwait(_assure)
def send_joins(self,
        names: List[str],
        keys: Optional[List[str]]=None
        ) -> Awaitable[List[Channel]]:
    """Join several channels with one JOIN line.

    *names* are the channels to join; *keys*, when given, are the channel
    keys, parallel to *names*. The returned awaitable yields the Channel
    objects once a reply has been seen for every requested name.
    """
    folded_names = [self.casefold(name) for name in names]

    params = [",".join(names)]
    if keys:
        # JOIN takes ONE comma-separated key list, parallel to the channel
        # list; separate parameters would leave every key but the first unused
        params.append(",".join(keys))
    fut = self.send(build("JOIN", params))

    async def _assure():
        channels: List[Channel] = []

        while folded_names:
            line = await self.wait_for({
                Response(RPL_CHANNELMODEIS, [ANY, ANY]),
                Responses(JOIN_ERR_FIRST, [ANY, ANY]),
                Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
                Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
            }, fut)

            chan: Optional[str] = None
            if line.command == RPL_CHANNELMODEIS:
                chan = line.params[1]
            elif line.command in JOIN_ERR_FIRST:
                chan = line.params[1]
            elif line.command == ERR_USERONCHANNEL:
                chan = line.params[2]
            elif line.command == ERR_LINKCHANNEL:
                #XXX i dont like this
                # we were forwarded: params[1] is the channel we asked for,
                # params[2] the one we end up in
                requested = self.casefold(line.params[1])
                chan = line.params[2]
                await self.wait_for(
                    Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)])
                )
                channels.append(self.channels[self.casefold(chan)])
                # the requested name is answered by the forward; leaving it
                # in folded_names would keep this loop waiting until timeout
                if requested in folded_names:
                    folded_names.remove(requested)
                continue

            if chan is not None:
                folded = self.casefold(chan)
                if folded in folded_names:
                    folded_names.remove(folded)
                    # NOTE(review): for the error numerics the channel was
                    # presumably never joined, so this lookup may raise
                    # KeyError - confirm the intended handling
                    channels.append(self.channels[folded])

        return channels

    return MaybeAwait(_assure)
def send_message(self, target: str, message: str
        ) -> Awaitable[Optional[str]]:
    """Send a PRIVMSG to *target*; the awaitable yields the text of our own
    PRIVMSG to that target once it has been read back."""
    fut = self.send(build("PRIVMSG", [target, message]))

    async def _assure():
        echo = Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF)
        line = await self.wait_for(echo, fut)
        return line.params[1] if line.command == "PRIVMSG" else None

    return MaybeAwait(_assure)
def send_whois(self,
        target: str,
        remote: bool=False
        ) -> Awaitable[Optional[Whois]]:
    """Send WHOIS for *target* and collect the replies into a Whois object.

    With *remote* set, the target is sent twice (``WHOIS target target``).
    The awaitable yields None on ERR_NOSUCHNICK / ERR_NOSUCHSERVER, and
    otherwise the populated Whois once RPL_ENDOFWHOIS arrives.
    """
    args = [target]
    if remote:
        args.append(target)
    fut = self.send(build("WHOIS", args))

    async def _assure() -> Optional[Whois]:
        folded = self.casefold(target)
        params = [ANY, Folded(folded)]

        obj = Whois()
        while True:
            line = await self.wait_for(Responses([
                ERR_NOSUCHNICK,
                ERR_NOSUCHSERVER,
                RPL_WHOISUSER,
                RPL_WHOISSERVER,
                RPL_WHOISOPERATOR,
                RPL_WHOISIDLE,
                RPL_WHOISCHANNELS,
                RPL_WHOISHOST,
                RPL_WHOISACCOUNT,
                RPL_WHOISSECURE,
                RPL_ENDOFWHOIS
            ], params), fut)

            if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
                return None
            elif line.command == RPL_WHOISUSER:
                nick, user, host, _, real = line.params[1:]
                obj.nickname = nick
                obj.username = user
                obj.hostname = host
                obj.realname = real
            elif line.command == RPL_WHOISIDLE:
                idle, signon, _ = line.params[2:]
                obj.idle = int(idle)
                obj.signon = int(signon)
            elif line.command == RPL_WHOISACCOUNT:
                obj.account = line.params[2]
            elif line.command == RPL_WHOISCHANNELS:
                if obj.channels is None:
                    obj.channels = []

                prefixes = self.isupport.prefix.prefixes
                for channel in line.params[2].split(" "):
                    # peel the membership prefixes off the front of the name
                    symbols = ""
                    while channel and channel[0] in prefixes:
                        symbols += channel[0]
                        channel = channel[1:]
                    if not channel:
                        # empty token (repeated spaces) or one consisting
                        # only of prefix characters: there is no channel
                        # name, and indexing channel[0] would raise
                        continue

                    channel_user = ChannelUser(
                        Name(obj.nickname, folded),
                        Name(channel, self.casefold(channel))
                    )
                    for symbol in symbols:
                        mode = self.isupport.prefix.from_prefix(symbol)
                        if mode is not None:
                            channel_user.modes.add(mode)

                    obj.channels.append(channel_user)
            elif line.command == RPL_ENDOFWHOIS:
                return obj

    return MaybeAwait(_assure)