# ircrobots/ircrobots/server.py
# (repository web-view header: 371 lines, 12 KiB, Python — Raw / Normal View / History)

import asyncio
from asyncio import Future, PriorityQueue
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
from collections import deque
from time import monotonic
from asyncio_throttle import Throttler
from async_timeout import timeout
from ircstates import Emit, Channel
from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG, resume_transmute)
from .sasl import SASLContext, SASLResult
from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
from .asyncs import MaybeAwait, WaitFor
from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter
# Flood control applied once registered (see Server.set_throttle):
# at most THROTTLE_RATE lines per THROTTLE_TIME seconds.
THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds
# Seconds of read silence before Server.next_line sends a PING, and again
# before it treats the connection as closed.
PING_TIMEOUT = 60 # seconds

class Server(IServer):
    """One connection to an IRC server.

    Outgoing lines go through a priority queue and a throttle
    (`send` / `_write_lines`); incoming lines are read, parsed and dispatched
    to handlers by `next_line`. Code running inside a line handler can pause
    until a matching server response arrives, via `wait_for`.

    Connection state (nickname, channels, agreed caps, `recv`, `casefold`)
    comes from the `IServer` base — presumably an ircstates Server; confirm.
    """
    _reader: ITCPReader
    _writer: ITCPWriter

    params: ConnectionParams

    def __init__(self, bot: IBot, name: str):
        super().__init__(name)
        self.bot = bot
        self.disconnected = False

        # generous limit until registration; RPL_WELCOME ("001") lowers it to
        # THROTTLE_RATE in _on_read_emit
        self.throttle = Throttler(
            rate_limit=100, period=THROTTLE_TIME)
        self.sasl_state = SASLResult.NONE
        self.last_read = -1.0

        self._sent_count: int = 0
        self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
        self.desired_caps: Set[ICapability] = set()
        # parsed lines not yet returned by next_line(): the rest of a
        # multi-line read, or locally echoed PRIVMSGs (_on_write_line)
        self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()

        # suspended handler tasks, each paired with the response it awaits
        self._wait_for: List[Tuple[Awaitable, WaitFor]] = []
        # future through which wait_for() hands a new WaitFor to the
        # _line_or_wait() call currently driving a handler
        self._wait_for_fut: Optional[Future[WaitFor]] = None

    def hostmask(self) -> str:
        """Return our own "nick!user@host", omitting parts not yet known."""
        hostmask = self.nickname
        if self.username is not None:
            hostmask += f"!{self.username}"
        if self.hostname is not None:
            hostmask += f"@{self.hostname}"
        return hostmask

    def send_raw(self, line: str, priority=SendPriority.DEFAULT
            ) -> Awaitable[SentLine]:
        """Tokenise a raw IRC line and queue it (see `send`)."""
        return self.send(tokenise(line), priority)

    def send(self, line: Line, priority=SendPriority.DEFAULT
            ) -> Awaitable[SentLine]:
        """Queue `line` for writing.

        If a labeled-response capability is available and `line` carries no
        label tag yet, one is added (this mutates `line.tags`), using the
        line's sequence id as the label.

        The line is queued immediately, so awaiting the result is optional;
        the awaitable resolves to the SentLine once `_write_lines` wrote it.
        """
        sent_line = SentLine(self._sent_count, priority, line)
        self._sent_count += 1

        label = self.cap_available(CAP_LABEL)
        if label is not None:
            tag = LABEL_TAG[label]
            if line.tags is None:
                line.tags = {}
            if tag not in line.tags:
                line.tags[tag] = str(sent_line.id)

        self._write_queue.put_nowait(sent_line)

        async def _assure() -> SentLine:
            await sent_line.future
            return sent_line
        return MaybeAwait(_assure)

    def set_throttle(self, rate: int, time: float):
        """Allow at most `rate` written lines per `time` seconds."""
        self.throttle.rate_limit = rate
        self.throttle.period = time

    def server_address(self) -> Tuple[str, int]:
        """Return the peer address reported by the connected writer."""
        return self._writer.get_peer()

    async def connect(self,
            transport: ITCPTransport,
            params: ConnectionParams):
        """Open a connection with `transport` and start registration.

        `params` is first passed through sts_transmute / resume_transmute
        (presumably applying stored STS / resume policies — confirm in ircv3).
        """
        await sts_transmute(params)
        await resume_transmute(params)

        reader, writer = await transport.connect(
            params.host,
            params.port,
            tls       =params.tls,
            tls_verify=params.tls_verify,
            bindhost  =params.bindhost)

        self._reader = reader
        self._writer = writer

        self.params = params
        await self.handshake()

    async def disconnect(self):
        """Close the connection if one is open and drop unread lines."""
        # getattr: _writer is only an annotation until connect() assigns it,
        # so disconnect() before connect() must not raise AttributeError
        writer = getattr(self, "_writer", None)
        if writer is not None:
            await writer.close()
            self._writer = None
        self._read_queue.clear()

    async def handshake(self):
        """Queue the registration lines: PASS (if set), CAP LS, NICK, USER."""
        nickname = self.params.nickname
        username = self.params.username or nickname
        realname = self.params.realname or nickname

        # these must remain non-awaited; reading hasn't started yet
        if self.params.password is not None:
            self.send(build("PASS", [self.params.password]))
        self.send(build("CAP", ["LS", "302"]))
        self.send(build("NICK", [nickname]))
        self.send(build("USER", [username, "0", "*", realname]))

    # to be overridden
    async def line_read(self, line: Line):
        """Hook: called for each read line, after the built-in handlers."""
        pass
    async def line_send(self, line: Line):
        """Hook: called for each line after it has been written."""
        pass
    async def sts_policy(self, sts: STSPolicy):
        """Hook: receives an STS policy (ignored by default)."""
        pass
    async def resume_policy(self, resume: ResumePolicy):
        """Hook: receives a resume policy (ignored by default)."""
        pass
    # /to be overridden

    async def _on_read_emit(self, line: Line, emit: Emit):
        """Built-in reactions to parsed server events."""
        if emit.command == "001":
            # registered: WHO ourselves (presumably to learn our user@host)
            # and switch to the normal flood-control rate
            await self.send(build("WHO", [self.nickname]))
            self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)

        elif emit.command == "CAP":
            if emit.subcommand == "NEW":
                await self._cap_ls(emit)
            elif (emit.subcommand == "LS" and
                    emit.finished):
                if not self.registered:
                    await CAPContext(self).handshake()
                else:
                    await self._cap_ls(emit)

        elif emit.command == "JOIN":
            if emit.self and emit.channel is not None:
                # fetch modes and member details for channels we join;
                # send_joins() relies on the RPL_CHANNELMODEIS this triggers
                await self.send(build("MODE", [emit.channel.name]))
                await WHOContext(self).ensure(emit.channel.name)

    async def _on_read_line(self, line: Line):
        """Answer server PINGs."""
        if line.command == "PING":
            await self.send(build("PONG", line.params))

    async def _line_or_wait(self, line_aw: Awaitable):
        """Drive `line_aw` until it finishes or calls `wait_for`.

        If it calls `wait_for`, its still-running task is parked in
        `self._wait_for` with the WaitFor, to be resumed by `next_line` when
        a matching line arrives.
        """
        wait_for_fut: Future[WaitFor] = Future()
        self._wait_for_fut = wait_for_fut

        # asyncio.wait() refuses bare coroutines on Python 3.11+, so wrap
        # them in a Task; Tasks resumed from self._wait_for pass through as-is
        line_task = asyncio.ensure_future(line_aw)
        # NOTE(review): an exception raised by the handler is not re-raised
        # here (same as before); asyncio only logs it when the task is freed
        await asyncio.wait([line_task, wait_for_fut],
            return_when=asyncio.FIRST_COMPLETED)

        # don't leave an unconsumed future behind: a later wait_for() would
        # resolve it with nobody listening and hang forever
        if self._wait_for_fut is wait_for_fut:
            self._wait_for_fut = None

        if wait_for_fut.done():
            self._wait_for.append((line_task, wait_for_fut.result()))

    async def next_line(self) -> Tuple[Line, Optional[Emit]]:
        """Return the next parsed (line, emit) pair, after handling it.

        Queued lines are returned first; otherwise the socket is read until
        at least one full line arrives. After PING_TIMEOUT seconds without
        data a PING is sent; if another PING_TIMEOUT passes in silence an
        empty read is fed to `recv`, as if the socket had closed.

        The line is run through the built-in handlers, `line_read`, and the
        first suspended `wait_for` caller whose response matches it.

        Raises ServerDisconnectedException (after setting `disconnected`)
        when `recv` raises it.
        """
        if self._read_queue:
            both = self._read_queue.popleft()
        else:
            ping_sent = False
            while True:
                try:
                    async with timeout(PING_TIMEOUT):
                        data = await self._reader.read(1024)
                except asyncio.TimeoutError:
                    if ping_sent:
                        data = b"" # empty data means the socket disconnected
                    else:
                        ping_sent = True
                        await self.send(build("PING", ["hello"]))
                        continue

                self.last_read = monotonic()
                ping_sent = False

                try:
                    lines = self.recv(data)
                except ServerDisconnectedException:
                    self.disconnected = True
                    raise

                if lines:
                    self._read_queue.extend(lines[1:])
                    both = lines[0]
                    break

        line, emit = both

        async def _line():
            if emit is not None:
                await self._on_read_emit(line, emit)
            await self._on_read_line(line)
            await self.line_read(line)

            # hand the line to the first parked handler waiting for it;
            # only one waiter is resumed per line
            for i, (waiter_aw, waiter) in enumerate(self._wait_for):
                if waiter.response.match(self, line):
                    waiter.resolve(line)
                    self._wait_for.pop(i)
                    await self._line_or_wait(waiter_aw)
                    break

        await self._line_or_wait(_line())
        return both

    async def wait_for(self, response: IMatchResponse) -> Line:
        """Suspend the current line handler until a line matches `response`.

        Only usable from code driven by `_line_or_wait` (i.e. inside line
        handling); anywhere else it raises Exception.
        """
        wait_for_fut = self._wait_for_fut
        if wait_for_fut is None:
            raise Exception(
                "wait_for() can only be used while a line is being handled")

        self._wait_for_fut = None
        our_wait_for = WaitFor(response)
        wait_for_fut.set_result(our_wait_for)
        return await our_wait_for

    async def _on_write_line(self, line: Line):
        """Locally echo our PRIVMSGs when the echo capability isn't agreed."""
        if (line.command == "PRIVMSG" and
                not self.cap_agreed(CAP_ECHO)):
            new_line = line.with_source(self.hostmask())
            emit = self.parse_tokens(new_line)
            self._read_queue.append((new_line, emit))

    async def _write_lines(self) -> List[Line]:
        """Write the next batch of queued lines and return them.

        Blocks for at least one line, then takes up to 5 in total if more are
        already queued. Each write goes through the throttle; once the batch
        is drained, write hooks run and each line's future is resolved.
        """
        sent_lines: List[SentLine] = []
        while (not sent_lines or
                (len(sent_lines) < 5 and self._write_queue.qsize() > 0)):
            sent_lines.append(await self._write_queue.get())

        for sent_line in sent_lines:
            async with self.throttle:
                self._writer.write(
                    f"{sent_line.line.format()}\r\n".encode("utf8"))

        await self._writer.drain()

        for sent_line in sent_lines:
            await self._on_write_line(sent_line.line)
            await self.line_send(sent_line.line)
            sent_line.future.set_result(sent_line)

        return [sent_line.line for sent_line in sent_lines]

    # CAP-related
    def cap_agreed(self, capability: ICapability) -> bool:
        """Whether `capability` is among the agreed caps."""
        return bool(self.cap_available(capability))

    def cap_available(self, capability: ICapability) -> Optional[str]:
        """The agreed cap name matching `capability`, or None."""
        return capability.available(self.agreed_caps)

    async def _cap_ls(self, emit: Emit):
        """Feed CAP LS/NEW tokens ("key" or "key=value") to CAPContext."""
        if emit.tokens is not None:
            tokens: Dict[str, str] = {}
            for token in emit.tokens:
                key, _, value = token.partition("=")
                tokens[key] = value
            await CAPContext(self).on_ls(tokens)

    async def sasl_auth(self, params: SASLParams) -> bool:
        """Attempt SASL once, if the sasl cap is agreed.

        Returns True if an attempt was made (its result is stored in
        `sasl_state`), False if SASL was already attempted or unavailable.
        """
        if (self.sasl_state == SASLResult.NONE and
                self.cap_agreed(CAP_SASL)):
            res = await SASLContext(self).from_params(params)
            self.sasl_state = res
            return True
        else:
            return False
    # /CAP-related

    def send_join(self,
            name: str,
            key: Optional[str]=None
            ) -> Awaitable[Channel]:
        """JOIN one channel; resolves to its Channel (see `send_joins`)."""
        fut = self.send_joins([name], [] if key is None else [key])

        async def _assure():
            channels = await fut
            return channels[0]
        return MaybeAwait(_assure)

    def send_part(self, name: str):
        """PART `name`; the awaitable resolves when our own PART is seen."""
        fut = self.send(build("PART", [name]))

        async def _assure():
            await fut
            await self.wait_for(Response(
                "PART",
                [Folded(self.casefold(name))],
                source=Nickname(self.nickname_lower)
            ))
        return MaybeAwait(_assure)

    def send_joins(self,
            names: List[str],
            keys: Optional[List[str]]=None
            ) -> Awaitable[List[Channel]]:
        """JOIN several channels in one line.

        Keys, if given, are sent as a single comma-separated parameter
        (``JOIN #a,#b key_a,key_b``), matched positionally to `names`.

        Resolves, once every channel has sent RPL_CHANNELMODEIS (requested in
        `_on_read_emit` when we join), to the Channel objects in reply order.
        NOTE(review): never resolves if a join fails.
        """
        folded_names = [self.casefold(name) for name in names]
        join_params = [",".join(names)]
        if keys:
            join_params.append(",".join(keys))
        fut = self.send(build("JOIN", join_params))

        async def _assure():
            await fut
            channels: List[Channel] = []
            while folded_names:
                line = await self.wait_for(
                    Response(RPL_CHANNELMODEIS, [ANY, ANY])
                )
                folded = self.casefold(line.params[1])
                if folded in folded_names:
                    folded_names.remove(folded)
                    channels.append(self.channels[folded])
            return channels
        return MaybeAwait(_assure)

    def send_whois(self, target: str) -> Awaitable[Whois]:
        """WHOIS `target`; resolves to a Whois at RPL_ENDOFWHOIS.

        The nick is sent twice (``WHOIS <server> <mask>`` form with the nick
        as server), presumably so the target's own server answers and
        reports idle time.
        """
        folded = self.casefold(target)
        fut = self.send(build("WHOIS", [target, target]))

        async def _assure():
            await fut
            params = [ANY, Folded(folded)]
            obj = Whois()
            while True:
                line = await self.wait_for(Responses([
                    RPL_WHOISUSER,
                    RPL_WHOISSERVER,
                    RPL_WHOISOPERATOR,
                    RPL_WHOISIDLE,
                    RPL_WHOISHOST,
                    RPL_WHOISACCOUNT,
                    RPL_WHOISSECURE,
                    RPL_ENDOFWHOIS
                ], params))
                if line.command == RPL_WHOISUSER:
                    obj.username, obj.hostname, _, obj.realname = line.params[2:]
                elif line.command == RPL_WHOISIDLE:
                    # "<client> <nick> <secs> [<signon>] :<text>"; older
                    # (RFC 1459) servers omit the signon field
                    obj.idle = int(line.params[2])
                    if len(line.params) >= 5:
                        obj.signon = int(line.params[3])
                elif line.command == RPL_WHOISACCOUNT:
                    obj.account = line.params[2]
                elif line.command == RPL_ENDOFWHOIS:
                    return obj
        return MaybeAwait(_assure)