add basic code for using labeled-responses as wait_for matches

This commit is contained in:
jesopo 2020-05-24 01:05:51 +01:00
parent 899c9c0b49
commit 33bcba8001
3 changed files with 51 additions and 11 deletions

View File

@ -1,8 +1,10 @@
from asyncio import Future
from irctokens import Line
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar
from .matching import IMatchResponse
from asyncio import Future
from irctokens import Line
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
TypeVar)
from .matching import IMatchResponse
from .interface import IServer
from .ircv3 import TAG_LABEL
TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]):
@ -16,9 +18,11 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object):
def __init__(self,
wait_fut: "Future[WaitFor]",
response: IMatchResponse):
response: IMatchResponse,
label: Optional[str]):
self._wait_fut = wait_fut
self.response = response
self._label = label
self.deferred = False
self._our_fut: "Future[Line]" = Future()
@ -30,6 +34,12 @@ class WaitFor(object):
return await self
def match(self, server: IServer, line: Line):
if (self._label is not None and
line.tags is not None):
label = TAG_LABEL.get(line.tags)
if (label is not None and
label == self._label):
return True
return self.response.match(server, line)
def resolve(self, line: Line):

View File

@ -40,13 +40,36 @@ class Capability(ICapability):
alias=self.alias,
depends_on=self.depends_on[:])
class MessageTag(object):
def __init__(self,
name: Optional[str],
draft_name: Optional[str]=None):
self.name = name
self.draft = draft_name
self._tags = [self.name, self.draft]
def available(self, tags: Iterable[str]) -> Optional[str]:
for tag in self._tags:
if tag is not None and tag in tags:
return tag
else:
return None
def get(self, tags: Dict[str, str]) -> Optional[str]:
name = self.available(tags)
if name is not None:
return tags[name]
else:
return None
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts")
CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume")
LABEL_TAG = {
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
TAG_LABEL = MessageTag("label", "draft/label")
LABEL_TAG_MAP = {
"draft/labeled-response-0.2": "draft/label",
"labeled-response": "label"
}

View File

@ -13,7 +13,7 @@ from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG, resume_transmute)
CAP_LABEL, LABEL_TAG_MAP, resume_transmute)
from .sasl import SASLContext, SASLResult
from .join_info import WHOContext
from .matching import (ResponseOr, Responses, Response, ANY, SELF, MASK_SELF,
@ -85,7 +85,7 @@ class Server(IServer):
label = self.cap_available(CAP_LABEL)
if not label is None:
tag = LABEL_TAG[label]
tag = LABEL_TAG_MAP[label]
if line.tags is None or not tag in line.tags:
if line.tags is None:
line.tags = {}
@ -259,8 +259,10 @@ class Server(IServer):
break
def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
response: Union[IMatchResponse, Set[IMatchResponse]],
sent_line: Optional[SentLine]=None
) -> Awaitable[Line]:
response_obj: IMatchResponse
if isinstance(response, set):
response_obj = ResponseOr(*response)
@ -270,7 +272,12 @@ class Server(IServer):
wait_for_fut = self._wait_for_fut
if wait_for_fut is not None:
self._wait_for_fut = None
our_wait_for = WaitFor(wait_for_fut, response_obj)
label: Optional[str] = None
if sent_line is not None:
label = str(sent_line.id)
our_wait_for = WaitFor(wait_for_fut, response_obj, label)
return our_wait_for
raise Exception()