allow ResponseOr to be shorthanded as a Set[IMatchResponse]

This commit is contained in:
jesopo 2020-04-27 01:28:46 +01:00
parent 769390baf7
commit 0921cb8086
5 changed files with 27 additions and 16 deletions

View File

@ -1,5 +1,5 @@
from asyncio import Future
from typing import Awaitable, Iterable, List, Optional, Set, Tuple
from typing import Awaitable, Iterable, List, Optional, Set, Tuple, Union
from enum import IntEnum
from ircstates import Server, Emit
@ -85,7 +85,9 @@ class IServer(Server):
) -> Awaitable[SentLine]:
pass
def wait_for(self, response: IMatchResponse) -> Awaitable[Line]:
def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Awaitable[Line]:
pass
def set_throttle(self, rate: int, time: float):

View File

@ -5,7 +5,7 @@ from irctokens import build
from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext
from .matching import Response, ResponseOr, ANY
from .matching import Response, ANY
from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy
@ -108,10 +108,10 @@ class CAPContext(ServerContext):
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
while cap_names:
line = await self.server.wait_for(ResponseOr(
line = await self.server.wait_for({
Response("CAP", [ANY, "ACK"]),
Response("CAP", [ANY, "NAK"])
))
})
current_caps = line.params[2].split(" ")
for cap in current_caps:
@ -136,10 +136,10 @@ class CAPContext(ServerContext):
if previous_policy is not None and not self.server.registered:
await self.server.send(build("RESUME", [previous_policy.token]))
line = await self.server.wait_for(ResponseOr(
line = await self.server.wait_for({
Response("RESUME", ["SUCCESS"]),
Response("FAIL", ["RESUME"])
))
})
if line.command == "RESUME":
raise HandshakeCancel()

View File

@ -3,7 +3,7 @@ from irctokens import build
from ircstates.numerics import *
from .contexts import ServerContext
from .matching import Response, ResponseOr, ANY, Folded
from .matching import Response, ANY, Folded
class WHOContext(ServerContext):
async def ensure(self, channel: str):

View File

@ -4,7 +4,7 @@ from base64 import b64decode, b64encode
from irctokens import build
from ircstates.numerics import *
from .matching import ResponseOr, Responses, Response, ANY
from .matching import Responses, Response, ANY
from .contexts import ServerContext
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext, SCRAMAlgorithm
@ -60,10 +60,10 @@ class SASLContext(ServerContext):
async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(ResponseOr(
line = await self.server.wait_for({
AUTHENTICATE_ANY,
NUMERICS_INITIAL
))
})
if line.command == "907":
# we've done SASL already. cleanly abort
@ -117,10 +117,10 @@ class SASLContext(ServerContext):
while match:
await self.server.send(build("AUTHENTICATE", [match[0]]))
line = await self.server.wait_for(ResponseOr(
line = await self.server.wait_for({
AUTHENTICATE_ANY,
NUMERICS_INITIAL
))
})
if line.command == "907":
# we've done SASL already. cleanly abort

View File

@ -1,6 +1,7 @@
import asyncio
from asyncio import Future, PriorityQueue
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple,
Union)
from collections import deque
from time import monotonic
@ -222,12 +223,20 @@ class Server(IServer):
return both
async def wait_for(self, response: IMatchResponse) -> Line:
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Line:
response_obj: IMatchResponse
if isinstance(response, set):
response_obj = ResponseOr(*response)
else:
response_obj = response
wait_for_fut = self._wait_for_fut
if wait_for_fut is not None:
self._wait_for_fut = None
our_wait_for = WaitFor(response)
our_wait_for = WaitFor(response_obj)
wait_for_fut.set_result(our_wait_for)
return await our_wait_for
raise Exception()