mirror of https://github.com/jesopo/ircrobots
292 lines
9.7 KiB
Python
292 lines
9.7 KiB
Python
from asyncio import Future, PriorityQueue
|
|
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
|
|
from collections import deque
|
|
|
|
from asyncio_throttle import Throttler
|
|
from ircstates import Emit, Channel, NUMERIC_NAMES
|
|
from irctokens import build, Line, tokenise
|
|
|
|
from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG
|
|
from .sasl import SASLContext, SASLResult
|
|
from .matching import ResponseOr, Numerics, Numeric, ParamAny, ParamFolded
|
|
from .asyncs import MaybeAwait
|
|
from .struct import Whois
|
|
|
|
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
|
|
SendPriority, SASLParams, IMatchResponse)
|
|
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
|
|
|
# Send throttle used once registration completes (Server sets it when
# RPL_WELCOME arrives): at most THROTTLE_RATE lines per THROTTLE_TIME
# seconds. THROTTLE_TIME is also the period of the looser initial limit.
THROTTLE_RATE: int = 4  # lines
THROTTLE_TIME: int = 2  # seconds
|
|
|
|
class Server(IServer):
    """One IRC server connection: send queue, read loop and helpers.

    Outgoing lines are queued by :meth:`send` and flushed by
    :meth:`_write_lines`. Incoming data is parsed by :meth:`next_line`,
    which also runs this class's built-in reactions (PONG replies, CAP
    handling, a WHO for ourselves after RPL_WELCOME, MODE after our own
    JOIN) and the :meth:`line_read` hook.
    """

    _reader: ITCPReader
    _writer: ITCPWriter
    params: ConnectionParams

    def __init__(self, name: str):
        super().__init__(name)

        # Loose limit (100 lines per THROTTLE_TIME seconds) until
        # registration; _on_read_emit tightens it on RPL_WELCOME (001).
        self.throttle = Throttler(
            rate_limit=100, period=THROTTLE_TIME)

        self.sasl_state = SASLResult.NONE

        # Incremented for every queued line and passed to SentLine
        # (presumably as its ``id``, which send() uses for label tags).
        self._sent_count: int = 0
        # Pending wait_for() requests. Each line read is offered to them in
        # order and consumed by the first one that matches.
        self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
        self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
        self.desired_caps: Set[ICapability] = set()

        # Parsed (line, emits) pairs waiting to be returned by next_line():
        # surplus lines from a multi-line read and locally echoed PRIVMSGs.
        self._read_queue: Deque[Tuple[Line, List[Emit]]] = deque()

    def hostmask(self) -> str:
        """Return our own ``nick[!user][@host]`` from the known state."""
        hostmask = self.nickname
        if self.username is not None:
            hostmask += f"!{self.username}"
        if self.hostname is not None:
            hostmask += f"@{self.hostname}"
        return hostmask

    def send_raw(self, line: str, priority=SendPriority.DEFAULT
            ) -> Awaitable[SentLine]:
        """Tokenise a raw IRC line and queue it; see :meth:`send`."""
        return self.send(tokenise(line), priority)

    def send(self, line: Line, priority=SendPriority.DEFAULT
            ) -> Awaitable[SentLine]:
        """Queue ``line`` for sending at ``priority``.

        The line is put on the write queue immediately, before the returned
        awaitable is awaited. If the labeled-response capability is agreed
        and the line carries no label tag yet, one is added using the sent
        line's id.

        :returns: an awaitable resolving to the :class:`SentLine` once
            :meth:`_write_lines` has written it.
        """
        sent_line = SentLine(self._sent_count, priority, line)
        self._sent_count += 1

        label = self.cap_available(CAP_LABEL)
        if label is not None:
            tag = LABEL_TAG[label]
            if line.tags is None or tag not in line.tags:
                if line.tags is None:
                    line.tags = {}
                line.tags[tag] = str(sent_line.id)

        self._write_queue.put_nowait(sent_line)

        async def _assure() -> SentLine:
            await sent_line.future
            return sent_line
        return MaybeAwait(_assure)

    def set_throttle(self, rate: int, time: float):
        """Limit sending to ``rate`` lines per ``time`` seconds."""
        self.throttle.rate_limit = rate
        self.throttle.period = time

    async def connect(self,
            transport: ITCPTransport,
            params: ConnectionParams):
        """Open a connection through ``transport`` and start registration."""
        reader, writer = await transport.connect(
            params.host,
            params.port,
            tls=params.tls,
            tls_verify=params.tls_verify,
            bindhost=params.bindhost)

        self._reader = reader
        self._writer = writer

        self.params = params
        await self.handshake()

    async def handshake(self):
        """Queue PASS (if set), CAP LS 302, NICK and USER.

        Username and realname fall back to the nickname when not given.
        """
        nickname = self.params.nickname
        username = self.params.username or nickname
        realname = self.params.realname or nickname

        # these must remain non-awaited; reading hasn't started yet
        if self.params.password is not None:
            self.send(build("PASS", [self.params.password]))
        self.send(build("CAP", ["LS", "302"]))
        self.send(build("NICK", [nickname]))
        self.send(build("USER", [username, "0", "*", realname]))

    async def _on_read_emit(self, line: Line, emit: Emit):
        """Built-in reactions to a parsed emit from an incoming line."""
        if emit.command == "001":
            # Registered: look ourselves up and switch to the strict throttle.
            await self.send(build("WHO", [self.nickname]))
            self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)

        elif emit.command == "CAP":
            if emit.subcommand == "NEW":
                await self._cap_ls(emit)
            elif (emit.subcommand == "LS" and
                    emit.finished):
                if not self.registered:
                    await CAPContext(self).handshake()
                else:
                    await self._cap_ls(emit)

        elif emit.command == "JOIN":
            # Our own JOIN: request channel modes (send_joins waits on the
            # RPL_CHANNELMODEIS reply).
            if emit.self and emit.channel is not None:
                await self.send(build("MODE", [emit.channel.name]))

    async def _on_read_line(self, line: Line):
        """Answer PING with a PONG carrying the same parameters."""
        if line.command == "PING":
            await self.send(build("PONG", line.params))

    async def line_read(self, line: Line):
        """Hook called for every line read; no-op by default."""
        pass

    async def next_line(self) -> Tuple[Line, List[Emit]]:
        """Return the next parsed ``(line, emits)`` pair, reading as needed.

        Buffered lines are served first. Otherwise data is read from the
        socket until at least one complete line has been parsed; any extra
        lines from that data are buffered for later calls. The built-in
        handlers and :meth:`line_read` run before returning.
        """
        if self._read_queue:
            both = self._read_queue.popleft()
        else:
            while True:
                # Read inside the loop: a chunk may hold only part of a line,
                # in which case recv() yields nothing and more data is needed.
                data = await self._reader.read(1024)
                lines = self.recv(data)
                if lines:
                    self._read_queue.extend(lines[1:])
                    both = lines[0]
                    break

        line, emits = both
        for emit in emits:
            await self._on_read_emit(line, emit)
        await self._on_read_line(line)
        await self.line_read(line)

        return both

    async def wait_for(self, response: IMatchResponse) -> Line:
        """Read lines until one matches ``response`` and return that line.

        While waiting, every line read is offered to all pending waiters
        (including ones registered by other callers) in registration order;
        the first waiter that matches consumes the line.
        """
        our_fut: "Future[Line]" = Future()
        self._wait_for.append((our_fut, response))
        # Stop as soon as *our* future resolves. Looping until every waiter
        # is satisfied would make a wait_for started while another is in
        # progress (e.g. from a read handler run by next_line) block on the
        # other caller's response.
        while not our_fut.done():
            line, _ = await self.next_line()

            for i, (fut, waiting) in enumerate(self._wait_for):
                if waiting.match(self, line):
                    fut.set_result(line)
                    self._wait_for.pop(i)
                    break

        return await our_fut

    async def line_send(self, line: Line):
        """Hook called for every line written; no-op by default."""
        pass

    async def _on_write_line(self, line: Line):
        """Locally echo our PRIVMSGs when echo-message is not agreed.

        The line is given our hostmask as source, parsed into state and
        queued for next_line(), as an echoed line from the server would be.
        """
        if (line.command == "PRIVMSG" and
                not self.cap_agreed(CAP_ECHO)):
            new_line = line.with_source(self.hostmask())
            emits = self.parse_tokens(new_line)
            self._read_queue.append((new_line, emits))

    async def _write_lines(self) -> List[Line]:
        """Write one batch of queued lines and return them.

        Waits for at least one queued line, then takes up to 5 in total
        while more are immediately available. Each write passes through
        the throttle; after a single drain every line's future is resolved
        and the write hooks run.
        """
        lines: List[SentLine] = []

        while (not lines or
                (len(lines) < 5 and self._write_queue.qsize() > 0)):
            prio_line = await self._write_queue.get()
            lines.append(prio_line)

        for line in lines:
            async with self.throttle:
                self._writer.write(
                    f"{line.line.format()}\r\n".encode("utf8"))

        await self._writer.drain()

        for line in lines:
            line.future.set_result(line)
            await self._on_write_line(line.line)
            await self.line_send(line.line)

        return [sent.line for sent in lines]

    # CAP-related
    def cap_agreed(self, capability: ICapability) -> bool:
        """True if ``capability`` is among the agreed capabilities."""
        return bool(self.cap_available(capability))

    def cap_available(self, capability: ICapability) -> Optional[str]:
        """Return the agreed name matching ``capability``, or None."""
        return capability.available(self.agreed_caps)

    async def _cap_ls(self, emit: Emit):
        """Split CAP ``key[=value]`` tokens and hand them to CAPContext.

        Tokens without ``=`` map to an empty string value.
        """
        if emit.tokens is not None:
            tokens: Dict[str, str] = {}
            for token in emit.tokens:
                key, _, value = token.partition("=")
                tokens[key] = value
            await CAPContext(self).on_ls(tokens)

    async def sasl_auth(self, params: SASLParams) -> bool:
        """Attempt SASL once, if the sasl capability has been agreed.

        :returns: True if an attempt was made (its outcome is stored in
            ``sasl_state``); False if SASL was already tried or not agreed.
        """
        if (self.sasl_state == SASLResult.NONE and
                self.cap_agreed(CAP_SASL)):
            res = await SASLContext(self).from_params(params)
            self.sasl_state = res
            return True
        else:
            return False
    # /CAP-related

    def send_join(self,
            name: str,
            key: Optional[str]=None
            ) -> Awaitable[Channel]:
        """Join one channel; see :meth:`send_joins`."""
        fut = self.send_joins([name], [] if key is None else [key])

        async def _assure():
            channels = await fut
            return channels[0]
        return MaybeAwait(_assure)

    def send_joins(self,
            names: List[str],
            keys: Optional[List[str]]=None
            ) -> Awaitable[List[Channel]]:
        """Join several channels with a single JOIN command.

        :param names: channel names to join.
        :param keys: channel keys, matched to ``names`` by position.
        :returns: an awaitable resolving to the joined :class:`Channel`
            objects, in the order their RPL_CHANNELMODEIS replies arrive
            (the MODE request is sent by _on_read_emit on our JOIN).
        """
        folded_names = [self.casefold(name) for name in names]

        # JOIN <chan>{,<chan>} [<key>{,<key>}]: keys form ONE comma-separated
        # parameter; extra parameters would be ignored by the server.
        params = [",".join(names)]
        if keys:
            params.append(",".join(keys))
        fut = self.send(build("JOIN", params))

        async def _assure():
            await fut

            channels: List[Channel] = []

            # NOTE(review): if a join fails, no MODE is requested for that
            # channel, so no RPL_CHANNELMODEIS arrives and this waits forever.
            while folded_names:
                line = await self.wait_for(
                    Numeric("RPL_CHANNELMODEIS", [ParamAny(), ParamAny()]))

                folded = self.casefold(line.params[1])
                if folded in folded_names:
                    folded_names.remove(folded)
                    channels.append(self.channels[folded])

            return channels
        return MaybeAwait(_assure)

    def send_whois(self, target: str) -> Awaitable[Whois]:
        """Send ``WHOIS target target`` and collect the replies.

        :returns: an awaitable resolving to a :class:`Whois` filled from
            RPL_WHOISUSER, RPL_WHOISIDLE and RPL_WHOISACCOUNT once
            RPL_ENDOFWHOIS arrives. Other matched numerics are ignored.
        """
        folded = self.casefold(target)
        fut = self.send(build("WHOIS", [target, target]))

        async def _assure():
            await fut
            params = [ParamAny(), ParamFolded(folded)]
            obj = Whois()
            while True:
                line = await self.wait_for(Numerics([
                    "RPL_WHOISUSER",
                    "RPL_WHOISSERVER",
                    "RPL_WHOISOPERATOR",
                    "RPL_WHOISIDLE",
                    "RPL_WHOISHOST",
                    "RPL_WHOISACCOUNT",
                    "RPL_WHOISSECURE",
                    "RPL_ENDOFWHOIS"
                ], params))

                if line.command == NUMERIC_NAMES["RPL_WHOISUSER"]:
                    # <me> <nick> <user> <host> * :<realname>
                    obj.username, obj.hostname, _, obj.realname = \
                        line.params[2:]
                elif line.command == NUMERIC_NAMES["RPL_WHOISIDLE"]:
                    # <me> <nick> <idle> [<signon>] :<text> - the signon
                    # time is absent in the RFC 2812 form of this numeric.
                    obj.idle = line.params[2]
                    if len(line.params) > 4:
                        obj.signon = line.params[3]
                elif line.command == NUMERIC_NAMES["RPL_WHOISACCOUNT"]:
                    obj.account = line.params[2]
                elif line.command == NUMERIC_NAMES["RPL_ENDOFWHOIS"]:
                    return obj
        return MaybeAwait(_assure)