From f70932ac4452e0f17f54a9ddd45e7672f727227e Mon Sep 17 00:00:00 2001 From: jesopo Date: Sun, 5 Apr 2020 12:48:29 +0100 Subject: [PATCH] move more CAP related stuff to CAPContext --- ircrobots/interface.py | 11 +++------- ircrobots/ircv3.py | 50 ++++++++++++++++++++++++------------------ ircrobots/sasl.py | 2 +- ircrobots/server.py | 19 +++------------- 4 files changed, 36 insertions(+), 46 deletions(-) diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 98989ad..780126b 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -1,5 +1,5 @@ from asyncio import Future -from typing import Awaitable, Iterable, List, Optional +from typing import Awaitable, Iterable, Set, Optional from enum import IntEnum from ircstates import Server @@ -33,7 +33,8 @@ class ICapability(object): pass class IServer(Server): - params: ConnectionParams + params: ConnectionParams + desired_caps: Set[ICapability] async def send_raw(self, line: str, priority=SendPriority.DEFAULT): pass @@ -49,9 +50,6 @@ class IServer(Server): async def connect(self, params: ConnectionParams): pass - async def queue_capability(self, cap: ICapability): - pass - async def line_read(self, line: Line): pass async def line_send(self, line: Line): @@ -65,8 +63,5 @@ class IServer(Server): def cap_available(self, capability: ICapability) -> Optional[str]: pass - def collect_caps(self) -> List[str]: - pass - async def sasl_auth(self, sasl: SASLParams) -> bool: pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index f266c76..3d81a32 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional +from typing import Dict, Iterable, List, Optional from irctokens import build from .contexts import ServerContext @@ -56,6 +56,33 @@ CAPS: List[ICapability] = [ ] class CAPContext(ServerContext): + async def on_ls(self, tokens: Dict[str, str]): + caps = list(self.server.desired_caps)+CAPS + + if (not self.server.params.sasl is None and + not CAP_SASL in caps): + caps.append(CAP_SASL) + + matched = (c.available(tokens) for c in caps) + cap_names = [name for name in matched if not name is None] + + if cap_names: + await self.server.send(build("CAP", ["REQ", " ".join(cap_names)])) + + while cap_names: + line = await self.server.wait_for(ResponseOr( + Response("CAP", [ParamAny(), ParamLiteral("ACK")]), + Response("CAP", [ParamAny(), ParamLiteral("NAK")]) + )) + + current_caps = line.params[2].split(" ") + for cap in current_caps: + if cap in cap_names: + cap_names.remove(cap) + if (self.server.cap_agreed(CAP_SASL) and + not self.server.params.sasl is None): + await self.server.sasl_auth(self.server.params.sasl) + async def handshake(self) -> bool: # improve this by being able to wait_for Emit objects line = await self.server.wait_for(ResponseOr( @@ -67,26 +94,7 @@ class CAPContext(ServerContext): )) if line.command == "CAP": - caps = self.server.collect_caps() - if caps: - await self.server.send( - build("CAP", ["REQ", " ".join(caps)])) - - while caps: - line = await self.server.wait_for(ResponseOr( - Response("CAP", [ParamAny(), ParamLiteral("ACK")]), - Response("CAP", [ParamAny(), ParamLiteral("NAK")]) - )) - - current_caps = line.params[2].split(" ") - for cap in current_caps: - if cap in caps: - caps.remove(cap) - - if (self.server.cap_agreed(CAP_SASL) and - not self.server.params.sasl is None): - await self.server.sasl_auth(self.server.params.sasl) - + await self.on_ls(self.server.available_caps) await self.server.send(build("CAP", ["END"])) return True else: diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index 1fc898f..489a4d5 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List from enum import Enum from base64 import b64decode, b64encode from irctokens import build diff --git a/ircrobots/server.py b/ircrobots/server.py index ac5325e..51c7488 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,13 +1,13 @@ import asyncio from ssl import SSLContext from asyncio import Future, PriorityQueue, Queue -from typing import Awaitable, List, Optional, Set, Tuple +from typing import List, Optional, Set, Tuple from asyncio_throttle import Throttler from ircstates import Emit from irctokens import build, Line, tokenise -from .ircv3 import CAPContext, CAPS, CAP_SASL +from .ircv3 import CAPContext, CAP_SASL from .interface import (ConnectionParams, ICapability, IServer, SentLine, SendPriority, SASLParams) from .matching import BaseResponse @@ -33,7 +33,7 @@ class Server(IServer): self._wait_for_cache: List[Tuple[Line, List[Emit]]] = [] self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue() - self._cap_queue: Set[ICapability] = set([]) + self.desired_caps: Set[ICapability] = set([]) async def send_raw(self, line: str, priority=SendPriority.DEFAULT): await self.send(tokenise(line), priority) @@ -140,24 +140,11 @@ class Server(IServer): return [l.line for l in lines] # CAP-related - async def queue_capability(self, cap: ICapability): - self._cap_queue.add(cap) - def cap_agreed(self, capability: ICapability) -> bool: return bool(self.cap_available(capability)) def cap_available(self, capability: ICapability) -> Optional[str]: return capability.available(self.agreed_caps) - def collect_caps(self) -> List[str]: - caps = CAPS+list(self._cap_queue) - self._cap_queue.clear() - - if not self.params.sasl is None: - caps.append(CAP_SASL) - - matched = [c.available(self.available_caps) for c in caps] - return [name for name in matched if not name is None] - async def _cap_new(self, emit: Emit): if not emit.tokens is None: tokens = [t.split("=", 1)[0] for t in emit.tokens]