ircrobots/ircrobots/server.py

504 lines
17 KiB
Python

import asyncio
from asyncio import Future, PriorityQueue
from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple,
Union)
from collections import deque
from time import monotonic
from asyncio_throttle import Throttler
from async_timeout import timeout
from ircstates import Emit, Channel, ChannelUser
from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
from .sasl import SASLContext, SASLResult
from .join_info import WHOContext
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
Folded)
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter
# Send throttle applied through set_throttle() once the 001 numeric has been
# seen in _on_read(); THROTTLE_TIME is also the period of the initial
# Throttler created in Server.__init__.
THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds
# How long _next_line() waits for data before it sends a PING; a second
# consecutive timeout is treated as a disconnect.
PING_TIMEOUT = 60 # seconds
# JOIN failure numerics for which send_joins() reads the channel name from
# params[1] of the reply.
JOIN_ERR_FIRST = [
ERR_NOSUCHCHANNEL,
ERR_BADCHANNAME,
ERR_UNAVAILRESOURCE,
ERR_TOOMANYCHANNELS,
ERR_BANNEDFROMCHAN,
ERR_INVITEONLYCHAN,
ERR_BADCHANNELKEY,
ERR_NEEDREGGEDNICK,
ERR_THROTTLE
]
# One IRC connection: queues outgoing lines (send/_send_lines), reads and
# dispatches incoming ones (_next_line/_read_lines/_on_read) and offers
# awaitable helpers such as send_join() and send_whois().
class Server(IServer):
# Read half of the transport; assigned in connect(), read in _next_line().
_reader: ITCPReader
# Write half of the transport; assigned in connect(), used by _send_lines()
# and server_address(), reset to None by disconnect().
_writer: ITCPWriter
# Parameters of the current connection; assigned in connect(), read by
# handshake().
params: ConnectionParams
def __init__(self, bot: IBot, name: str):
    """Create a not-yet-connected server owned by *bot*.

    *name* is forwarded to the base class constructor. No I/O happens
    here; the transport is attached later by connect().
    """
    super().__init__(name)
    self.bot = bot

    # connection state
    self.disconnected = False
    self.last_read = -1.0  # monotonic() of the last read; -1.0 until then
    self.sasl_state = SASLResult.NONE
    self.desired_caps: Set[ICapability] = set()

    # outgoing side: the rate limit starts at 100 lines per period and is
    # lowered through set_throttle() when the 001 numeric arrives
    self.throttle = Throttler(rate_limit=100, period=THROTTLE_TIME)
    self._sent_count: int = 0
    self._send_queue: PriorityQueue[SentLine] = PriorityQueue()

    # incoming side
    self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
    self._wait_for_fut: Optional[Future[WaitFor]] = None
def hostmask(self) -> str:
    """Return our own hostmask as ``nick[!user][@host]``.

    The ``!user`` and ``@host`` parts are only included when the
    corresponding attribute is not None.
    """
    parts = [self.nickname]
    if self.username is not None:
        parts.append(f"!{self.username}")
    if self.hostname is not None:
        parts.append(f"@{self.hostname}")
    return "".join(parts)
def send_raw(self, line: str, priority=SendPriority.DEFAULT
        ) -> Awaitable[SentLine]:
    """Tokenise the raw IRC *line* and queue it through send()."""
    parsed = tokenise(line)
    return self.send(parsed, priority)
def send(self, line: Line, priority=SendPriority.DEFAULT
        ) -> Awaitable[SentLine]:
    """Queue *line* for sending and return an awaitable for its SentLine.

    The line is put on the send queue synchronously, so the call takes
    effect even if the result is never awaited. Awaiting the result waits
    on the SentLine's future and then returns the SentLine.
    """
    self.line_presend(line)
    queued = SentLine(self._sent_count, priority, line)
    self._sent_count += 1

    # when a label capability is available, tag the line with its send id
    # unless the caller already supplied that tag
    label_cap = self.cap_available(CAP_LABEL)
    if label_cap is not None:
        label_tag = LABEL_TAG_MAP[label_cap]
        if line.tags is None:
            line.tags = {}
        if label_tag not in line.tags:
            line.tags[label_tag] = str(queued.id)

    self._send_queue.put_nowait(queued)

    async def _assure() -> SentLine:
        await queued.future
        return queued
    return MaybeAwait(_assure)
def set_throttle(self, rate: int, time: float):
    """Allow at most *rate* throttled sends per *time* seconds."""
    throttler = self.throttle
    throttler.rate_limit = rate
    throttler.period = time
def server_address(self) -> Tuple[str, int]:
    """Return the peer address reported by the current writer."""
    peer = self._writer.get_peer()
    return peer
async def connect(self,
        transport: ITCPTransport,
        params: ConnectionParams):
    """Open a connection described by *params* and start the handshake.

    *params* is first passed through sts_transmute() and
    resume_transmute(), then *transport* opens the connection and the
    resulting reader/writer pair is stored on the instance.
    """
    await sts_transmute(params)
    await resume_transmute(params)

    connection = await transport.connect(
        params.host,
        params.port,
        tls=params.tls,
        tls_verify=params.tls_verify,
        bindhost=params.bindhost)
    self._reader, self._writer = connection
    self.params = params

    await self.handshake()
async def disconnect(self):
    """Close the connection, if there is one, and drop buffered reads.

    Safe to call before connect() and more than once: `_writer` is only
    declared as a class annotation and first assigned in connect(), so it
    is looked up defensively instead of raising AttributeError.
    """
    writer = getattr(self, "_writer", None)
    if writer is not None:
        await writer.close()
        self._writer = None
        # anything still buffered belongs to the connection just closed
        self._read_queue.clear()
async def handshake(self):
    """Queue the registration lines: optional PASS, CAP LS 302, NICK, USER.

    Username and realname fall back to the nickname when unset.
    """
    params = self.params
    nickname = params.nickname
    username = params.username or nickname
    realname = params.realname or nickname

    # send() queues synchronously; do not await these, because nothing is
    # reading from the connection yet at this point
    if params.password is not None:
        self.send(build("PASS", [params.password]))
    self.send(build("CAP", ["LS", "302"]))
    self.send(build("NICK", [nickname]))
    self.send(build("USER", [username, "0", "*", realname]))
# to be overridden
def line_preread(self, line: Line):
    """Hook for subclasses; does nothing by default.

    Called synchronously by the read loop for every line right after it
    has been received, before the line is dispatched.
    """
def line_presend(self, line: Line):
    """Hook for subclasses; does nothing by default.

    Called synchronously at the start of send(), before the line is
    tagged and queued.
    """
async def line_read(self, line: Line):
    """Hook for subclasses; does nothing by default.

    Awaited at the end of _on_read() for every dispatched line.
    """
async def line_send(self, line: Line):
    """Hook for subclasses; does nothing by default.

    Awaited by _send_lines() for each line after its batch has been
    written and drained.
    """
async def sts_policy(self, sts: STSPolicy):
    """Hook for subclasses; does nothing by default.

    NOTE(review): not invoked anywhere in this part of the file;
    presumably called when an STS policy is learned - confirm against
    the CAP handling code.
    """
async def resume_policy(self, resume: ResumePolicy):
    """Hook for subclasses; does nothing by default.

    NOTE(review): not invoked anywhere in this part of the file;
    presumably called when a resume policy is learned - confirm against
    the CAP handling code.
    """
# /to be overridden
async def _on_read(self, line: Line, emit: Optional[Emit]):
    """React to one incoming line, then hand it to the line_read() hook.

    Built-in reactions: answer PING with PONG; on 001 send WHO for
    ourselves and switch to the steady-state throttle; drive capability
    negotiation on CAP LS/NEW; on our own JOIN request the channel's
    modes and WHO information.
    """
    if line.command == "PING":
        await self.send(build("PONG", line.params))

    elif emit is not None:
        command = emit.command

        if command == "001":
            await self.send(build("WHO", [self.nickname]))
            self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)

        elif command == "CAP":
            if emit.subcommand == "NEW":
                await self._cap_ls(emit)
            elif emit.subcommand == "LS" and emit.finished:
                if self.registered:
                    await self._cap_ls(emit)
                else:
                    # still registering: run the full CAP handshake
                    await CAPContext(self).handshake()

        elif command == "JOIN":
            if emit.self and emit.channel is not None:
                channel_name = emit.channel.name
                await self.send(build("MODE", [channel_name]))
                await WHOContext(self).ensure(channel_name)

    await self.line_read(line)
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
    """Return the next (line, emit) pair, reading from the socket if needed.

    Buffered pairs in the read queue are served first. Otherwise data is
    read with a PING_TIMEOUT deadline: the first timeout sends a PING,
    and a second consecutive timeout is handled as if empty data had
    been read. ServerDisconnectedException from recv() marks the server
    as disconnected and is re-raised.
    """
    if self._read_queue:
        return self._read_queue.popleft()

    awaiting_pong = False
    while True:
        try:
            async with timeout(PING_TIMEOUT):
                data = await self._reader.read(1024)
        except asyncio.TimeoutError:
            if not awaiting_pong:
                awaiting_pong = True
                await self.send(build("PING", ["hello"]))
                continue
            # no reply to our PING either; empty data means the socket
            # disconnected
            data = b""

        self.last_read = monotonic()
        awaiting_pong = False

        try:
            parsed = self.recv(data)
        except ServerDisconnectedException:
            self.disconnected = True
            raise

        if parsed:
            # keep the remainder for later calls, hand out the first pair
            self._read_queue.extend(parsed[1:])
            return parsed[0]
async def _line_or_wait(self, line_aw: Awaitable):
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (await wait_for_fut), new_line_aw
else:
return None, None
async def _read_lines(self):
    """Read loop: pull lines, dispatch them, and service wait_for() calls.

    Every received pair is appended to a backlog. While a handler is
    suspended in wait_for(), new lines are only matched against that
    wait; once it matches, the handler is resumed. When no wait is
    active, the backlog is dispatched through _on_read() one pair at a
    time until it is empty or a handler starts a new wait.
    """
    backlog: Deque[Tuple[Line, Optional[Emit]]] = deque()
    active_wait: Optional[WaitFor] = None
    suspended: Optional[Awaitable] = None

    async def _pull() -> Tuple[Line, Optional[Emit]]:
        # fetch one pair, remember it for later dispatch, run the hook
        pair = await self._next_line()
        backlog.append(pair)
        pulled_line, _pulled_emit = pair
        self.line_preread(pulled_line)
        return pair

    while True:
        if active_wait is not None:
            candidate, _ = await _pull()
            if active_wait.response.match(self, candidate):
                active_wait.resolve(candidate)
                # resume the suspended handler; it may wait again
                active_wait, suspended = await self._line_or_wait(
                    suspended)
            continue

        if not backlog:
            await _pull()
        while backlog:
            next_line, next_emit = backlog.popleft()
            handler = self._on_read(next_line, next_emit)
            active_wait, suspended = await self._line_or_wait(handler)
            if active_wait is not None:
                break
def wait_for(self,
        response: Union[IMatchResponse, Set[IMatchResponse]],
        sent_line: Optional[SentLine]=None
        ) -> Awaitable[Line]:
    """Return an awaitable for the next incoming line matching *response*.

    A set of responses is combined with ResponseOr, so any member may
    match. When *sent_line* is given, its id is passed to the WaitFor as
    the label.

    Raises:
        Exception: when no read-loop future is pending. The future is
            published by _line_or_wait() and consumed here, so this must
            be called from within line handling, and only once per
            published future.
    """
    response_obj: IMatchResponse
    if isinstance(response, set):
        response_obj = ResponseOr(*response)
    else:
        response_obj = response

    wait_for_fut = self._wait_for_fut
    if wait_for_fut is None:
        raise Exception(
            "wait_for() called with no pending read-loop future; it must "
            "be called while a line is being handled, and only once per "
            "pending future")

    # consume the future so a second wait_for() cannot reuse it
    self._wait_for_fut = None
    label: Optional[str] = None if sent_line is None else str(sent_line.id)
    return WaitFor(wait_for_fut, response_obj, label)
async def _on_send_line(self, line: Line):
    """Feed our own PRIVMSG back into the read queue when CAP_ECHO is off.

    The echoed copy carries our own hostmask as its source and is parsed
    with parse_tokens() before being queued.
    """
    if line.command != "PRIVMSG" or self.cap_agreed(CAP_ECHO):
        return

    echoed = line.with_source(self.hostmask())
    echoed_emit = self.parse_tokens(echoed)
    self._read_queue.append((echoed, echoed_emit))
async def _send_lines(self):
    """Send loop: drain the send queue in batches of up to five lines.

    Each line is written under the throttle, the batch is drained once,
    and then for every line the echo handling and the line_send() hook
    run before its future is resolved.
    """
    while True:
        # block for the first line, then take whatever else is already
        # queued, up to five lines in total
        batch: List[SentLine] = [await self._send_queue.get()]
        while len(batch) < 5 and self._send_queue.qsize() > 0:
            batch.append(await self._send_queue.get())

        for sent in batch:
            async with self.throttle:
                payload = f"{sent.line.format()}\r\n".encode("utf8")
                self._writer.write(payload)

        await self._writer.drain()

        for sent in batch:
            await self._on_send_line(sent.line)
            await self.line_send(sent.line)
            sent.future.set_result(sent)
# CAP-related
def cap_agreed(self, capability: ICapability) -> bool:
    """Return True when cap_available() yields a truthy value."""
    available = self.cap_available(capability)
    return True if available else False
def cap_available(self, capability: ICapability) -> Optional[str]:
    """Ask *capability* what it resolves to among the agreed caps."""
    agreed = self.agreed_caps
    return capability.available(agreed)
async def _cap_ls(self, emit: Emit):
    """Forward the ``key[=value]`` tokens of a CAP emit to CAPContext.on_ls().

    Does nothing when the emit carries no tokens. A token without ``=``
    gets an empty string as its value.
    """
    if emit.tokens is None:
        return

    tokens: Dict[str, str] = {
        key: value
        for key, _, value in (raw.partition("=") for raw in emit.tokens)
    }
    await CAPContext(self).on_ls(tokens)
async def sasl_auth(self, params: SASLParams) -> bool:
    """Attempt SASL authentication with *params*.

    Returns False without doing anything unless no SASL result has been
    recorded yet and CAP_SASL is agreed. Otherwise runs the SASL exchange,
    stores its result in ``sasl_state`` and returns True; the return value
    says that an attempt was made, not that it succeeded.
    """
    can_attempt = (self.sasl_state == SASLResult.NONE and
        self.cap_agreed(CAP_SASL))
    if not can_attempt:
        return False

    self.sasl_state = await SASLContext(self).from_params(params)
    return True
# /CAP-related
def send_nick(self, new_nick: str) -> Awaitable[bool]:
    """Send NICK *new_nick*; awaiting yields True if the change happened.

    The awaitable resolves on our own NICK line for *new_nick*, or on one
    of the nick-change error numerics, in which case it yields False.
    """
    sent = self.send(build("NICK", [new_nick]))

    async def _assure() -> bool:
        await sent
        changed = Response("NICK", [Folded(new_nick)], source=MASK_SELF)
        refused = Responses([
            ERR_BANNICKCHANGE,
            ERR_NICKTOOFAST,
            ERR_CANTCHANGENICK
        ], [ANY])
        nick_rejected = Responses([
            ERR_NICKNAMEINUSE,
            ERR_ERRONEUSNICKNAME,
            ERR_UNAVAILRESOURCE
        ], [ANY, Folded(new_nick)])

        reply = await self.wait_for({changed, refused, nick_rejected})
        return reply.command == "NICK"
    return MaybeAwait(_assure)
def send_join(self,
        name: str,
        key: Optional[str]=None
        ) -> Awaitable[Channel]:
    """Join a single channel; awaiting yields the first joined Channel.

    Thin wrapper around send_joins() with one name and at most one key.
    """
    keys = [key] if key is not None else []
    joined = self.send_joins([name], keys)

    async def _assure():
        return (await joined)[0]
    return MaybeAwait(_assure)
def send_part(self, name: str):
    """Send PART for *name*; awaiting waits for our own PART line back."""
    sent = self.send(build("PART", [name]))

    async def _assure():
        await sent
        await self.wait_for(
            Response("PART", [Folded(name)], source=MASK_SELF))
    return MaybeAwait(_assure)
def send_joins(self,
        names: List[str],
        keys: Optional[List[str]]=None
        ) -> Awaitable[List[Channel]]:
    """Join several channels with one JOIN; awaiting yields the Channels.

    *names* and *keys* are each sent as a single comma-separated
    parameter (``JOIN #a,#b key1,key2``). *keys* may be omitted or empty.

    The awaitable collects one outcome per requested channel: a
    RPL_CHANNELMODEIS (triggered by the MODE sent from _on_read() on our
    own JOIN), one of the JOIN error numerics, or ERR_LINKCHANNEL when
    the server forwards us to a different channel.
    """
    folded_names = [self.casefold(name) for name in names]

    join_params = [",".join(names)]
    if keys:
        # the key list is ONE comma-separated parameter, not one
        # parameter per key
        join_params.append(",".join(keys))
    fut = self.send(build("JOIN", join_params))

    async def _assure():
        await fut
        channels: List[Channel] = []

        while folded_names:
            line = await self.wait_for({
                Response(RPL_CHANNELMODEIS, [ANY, ANY]),
                Responses(JOIN_ERR_FIRST, [ANY, ANY]),
                Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]),
                Response(ERR_LINKCHANNEL, [ANY, ANY, ANY])
            })

            chan: Optional[str] = None
            if line.command == RPL_CHANNELMODEIS:
                chan = line.params[1]
            elif line.command in JOIN_ERR_FIRST:
                chan = line.params[1]
            elif line.command == ERR_USERONCHANNEL:
                chan = line.params[2]
            elif line.command == ERR_LINKCHANNEL:
                #XXX i dont like this
                # params[1] is the channel we asked for, params[2] the
                # channel we are forwarded to
                requested = self.casefold(line.params[1])
                chan = line.params[2]
                await self.wait_for(
                    Response(RPL_CHANNELMODEIS, [ANY, Folded(chan)])
                )
                channels.append(self.channels[self.casefold(chan)])
                # the request is settled by the forward; without this the
                # loop would wait forever for the original name
                if requested in folded_names:
                    folded_names.remove(requested)
                continue

            if chan is not None:
                folded = self.casefold(chan)
                if folded in folded_names:
                    folded_names.remove(folded)
                    # NOTE(review): after a JOIN error the channel may be
                    # absent from self.channels, making this a KeyError -
                    # confirm the intended result for failed joins.
                    channels.append(self.channels[folded])

        return channels
    return MaybeAwait(_assure)
def send_message(self, target: str, message: str
        ) -> Awaitable[Optional[str]]:
    """Send a PRIVMSG; awaiting yields the message text seen coming back.

    The awaitable waits for a PRIVMSG to *target* from ourselves and
    returns its second parameter.
    """
    sent = self.send(build("PRIVMSG", [target, message]))

    async def _assure():
        await sent
        echoed = await self.wait_for(
            Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF))
        return echoed.params[1] if echoed.command == "PRIVMSG" else None
    return MaybeAwait(_assure)
def send_whois(self,
        target: str,
        remote: bool=False
        ) -> Awaitable[Optional[Whois]]:
    """Send WHOIS for *target*; awaiting yields a Whois or None.

    With *remote* set, the target is sent twice (``WHOIS target target``).
    The awaitable gathers WHOIS replies for the target until
    RPL_ENDOFWHOIS and returns the filled Whois, or None on
    ERR_NOSUCHNICK / ERR_NOSUCHSERVER.
    """
    args = [target]
    if remote:
        args.append(target)

    fut = self.send(build("WHOIS", args))

    async def _assure() -> Optional[Whois]:
        await fut
        params = [ANY, Folded(self.casefold(target))]

        obj = Whois()
        while True:
            line = await self.wait_for(Responses([
                ERR_NOSUCHNICK,
                ERR_NOSUCHSERVER,
                RPL_WHOISUSER,
                RPL_WHOISSERVER,
                RPL_WHOISOPERATOR,
                RPL_WHOISIDLE,
                RPL_WHOISCHANNELS,
                RPL_WHOISHOST,
                RPL_WHOISACCOUNT,
                RPL_WHOISSECURE,
                RPL_ENDOFWHOIS
            ], params))

            if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]:
                return None
            elif line.command == RPL_WHOISUSER:
                nick, user, host, _, real = line.params[1:]
                obj.nickname = nick
                obj.username = user
                obj.hostname = host
                obj.realname = real
            elif line.command == RPL_WHOISIDLE:
                idle, signon, _ = line.params[2:]
                obj.idle = int(idle)
                obj.signon = int(signon)
            elif line.command == RPL_WHOISACCOUNT:
                obj.account = line.params[2]
            elif line.command == RPL_WHOISCHANNELS:
                channels = list(filter(bool, line.params[2].split(" ")))
                if obj.channels is None:
                    obj.channels = []

                prefixes = self.isupport.prefix.prefixes
                for channel in channels:
                    # strip leading status prefixes; the emptiness check
                    # keeps a token made only of prefix symbols from
                    # raising IndexError
                    symbols = ""
                    while channel and channel[0] in prefixes:
                        symbols += channel[0]
                        channel = channel[1:]

                    # NOTE(review): the stripped channel name is not
                    # stored on the ChannelUser; only the modes are kept.
                    channel_user = ChannelUser()
                    for symbol in symbols:
                        mode = self.isupport.prefix.from_prefix(symbol)
                        if mode is not None:
                            channel_user.modes.append(mode)

                    obj.channels.append(channel_user)
            elif line.command == RPL_ENDOFWHOIS:
                return obj
    return MaybeAwait(_assure)