move more CAP related stuff to CAPContext

This commit is contained in:
jesopo 2020-04-05 12:48:29 +01:00
parent afe9ec359d
commit f70932ac44
4 changed files with 36 additions and 46 deletions

View File

@ -1,5 +1,5 @@
from asyncio import Future
from typing import Awaitable, Iterable, List, Optional
from typing import Awaitable, Iterable, Set, Optional
from enum import IntEnum
from ircstates import Server
@ -33,7 +33,8 @@ class ICapability(object):
pass
class IServer(Server):
params: ConnectionParams
params: ConnectionParams
desired_caps: Set[ICapability]
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
pass
@ -49,9 +50,6 @@ class IServer(Server):
async def connect(self, params: ConnectionParams):
pass
async def queue_capability(self, cap: ICapability):
pass
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
@ -65,8 +63,5 @@ class IServer(Server):
def cap_available(self, capability: ICapability) -> Optional[str]:
pass
def collect_caps(self) -> List[str]:
pass
async def sasl_auth(self, sasl: SASLParams) -> bool:
pass

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Optional
from typing import Dict, Iterable, List, Optional
from irctokens import build
from .contexts import ServerContext
@ -56,6 +56,33 @@ CAPS: List[ICapability] = [
]
class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]):
caps = list(self.server.desired_caps)+CAPS
if (not self.server.params.sasl is None and
not CAP_SASL in caps):
caps.append(CAP_SASL)
matched = (c.available(tokens) for c in caps)
cap_names = [name for name in matched if not name is None]
if cap_names:
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
while cap_names:
line = await self.server.wait_for(ResponseOr(
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
))
current_caps = line.params[2].split(" ")
for cap in current_caps:
if cap in cap_names:
cap_names.remove(cap)
if (self.server.cap_agreed(CAP_SASL) and
not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl)
async def handshake(self) -> bool:
# improve this by being able to wait_for Emit objects
line = await self.server.wait_for(ResponseOr(
@ -67,26 +94,7 @@ class CAPContext(ServerContext):
))
if line.command == "CAP":
caps = self.server.collect_caps()
if caps:
await self.server.send(
build("CAP", ["REQ", " ".join(caps)]))
while caps:
line = await self.server.wait_for(ResponseOr(
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
))
current_caps = line.params[2].split(" ")
for cap in current_caps:
if cap in caps:
caps.remove(cap)
if (self.server.cap_agreed(CAP_SASL) and
not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl)
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))
return True
else:

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from enum import Enum
from base64 import b64decode, b64encode
from irctokens import build

View File

@ -1,13 +1,13 @@
import asyncio
from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue
from typing import Awaitable, List, Optional, Set, Tuple
from typing import List, Optional, Set, Tuple
from asyncio_throttle import Throttler
from ircstates import Emit
from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAPS, CAP_SASL
from .ircv3 import CAPContext, CAP_SASL
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority, SASLParams)
from .matching import BaseResponse
@ -33,7 +33,7 @@ class Server(IServer):
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
self._cap_queue: Set[ICapability] = set([])
self.desired_caps: Set[ICapability] = set([])
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
await self.send(tokenise(line), priority)
@ -140,24 +140,11 @@ class Server(IServer):
return [l.line for l in lines]
# CAP-related
async def queue_capability(self, cap: ICapability):
self._cap_queue.add(cap)
def cap_agreed(self, capability: ICapability) -> bool:
return bool(self.cap_available(capability))
def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps)
def collect_caps(self) -> List[str]:
caps = CAPS+list(self._cap_queue)
self._cap_queue.clear()
if not self.params.sasl is None:
caps.append(CAP_SASL)
matched = [c.available(self.available_caps) for c in caps]
return [name for name in matched if not name is None]
async def _cap_new(self, emit: Emit):
if not emit.tokens is None:
tokens = [t.split("=", 1)[0] for t in emit.tokens]