add numerics.py to translate names, remove Response(errors=)

This commit is contained in:
jesopo 2020-04-02 20:55:01 +01:00
parent 06a4d20fc8
commit a4f5d8045f
4 changed files with 48 additions and 23 deletions

View File

@ -2,7 +2,8 @@ from typing import Iterable, List, Optional
from irctokens import build
from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot,
ParamLiteral)
from .interface import ICapability
class Capability(ICapability):
@ -57,10 +58,13 @@ CAPS: List[ICapability] = [
class CAPContext(ServerContext):
async def handshake(self) -> bool:
# improve this by being able to wait_for Emit objects
line = await self.server.wait_for(Response(
"CAP",
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
errors=["001"]))
line = await self.server.wait_for(ResponseOr(
Response(
"CAP",
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
),
Numerics(["RPL_WELCOME"])
))
if line.command == "CAP":
caps = self.server.collect_caps()

View File

@ -1,5 +1,6 @@
from typing import List, Optional
from typing import List
from irctokens import Line
from .numerics import NUMERIC_NAMES
class ResponseParam(object):
def match(self, arg: str) -> bool:
@ -12,7 +13,7 @@ class BaseResponse(object):
class Numerics(BaseResponse):
def __init__(self,
numerics: List[str]):
self._numerics = numerics
self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
def match(self, line: Line):
return line.command in self._numerics
@ -20,11 +21,9 @@ class Numerics(BaseResponse):
class Response(BaseResponse):
def __init__(self,
command: str,
params: List[ResponseParam],
errors: Optional[List[str]] = None):
params: List[ResponseParam]):
self._command = command
self._params = params
self._errors = errors or []
def match(self, line: Line) -> bool:
if line.command == self._command:
@ -34,8 +33,6 @@ class Response(BaseResponse):
return False
else:
return True
elif line.command in self._errors:
return True
else:
return False

16
ircrobots/numerics.py Normal file
View File

@ -0,0 +1,16 @@
NUMERIC_NUMBERS = {}
NUMERIC_NAMES = {}
def _numeric(number: str, name: str):
NUMERIC_NUMBERS[number] = name
NUMERIC_NAMES[name] = number
_numeric("001", "RPL_WELCOME")
_numeric("005", "RPL_ISUPPORT")
_numeric("903", "RPL_SASLSUCCESS")
_numeric("904", "ERR_SASLFAIL")
_numeric("905", "ERR_SASLTOOLONG")
_numeric("906", "ERR_SASLABORTED")
_numeric("907", "ERR_SASLALREADY")
_numeric("908", "RPL_SASLMECHS")

View File

@ -3,7 +3,7 @@ from enum import Enum
from base64 import b64encode
from irctokens import build
from .matching import Response, Numerics, ParamAny
from .matching import Response, ResponseOr, Numerics, ParamAny
from .contexts import ServerContext
from .params import SASLParams
@ -25,6 +25,10 @@ class SASLError(Exception):
class SASLUnknownMechanismError(SASLError):
pass
NUMERICS_INITIAL = Numerics(
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult:
if params.mechanism == "USERPASS":
@ -38,8 +42,10 @@ class SASLContext(ServerContext):
async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(Response("AUTHENTICATE",
[ParamAny()], errors=["904", "907", "908"]))
line = await self.server.wait_for(ResponseOr(
Response("AUTHENTICATE", [ParamAny()]),
NUMERICS_INITIAL
))
if line.command == "907":
# we've done SASL already. cleanly abort
@ -52,7 +58,7 @@ class SASLContext(ServerContext):
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.server.send(build("AUTHENTICATE", ["+"]))
line = await self.server.wait_for(Numerics(["903", "904"]))
line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE
@ -80,8 +86,10 @@ class SASLContext(ServerContext):
match = SASL_USERPASS_MECHANISMS[0]
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(Response("AUTHENTICATE",
[ParamAny()], errors=["904", "907", "908"]))
line = await self.server.wait_for(ResponseOr(
Response("AUTHENTICATE", [ParamAny()]),
NUMERICS_INITIAL
))
if line.command == "907":
# we've done SASL already. cleanly abort
@ -92,8 +100,8 @@ class SASLContext(ServerContext):
match = _common(available)
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(Response("AUTHENTICATE",
[ParamAny()]))
line = await self.server.wait_for(
Response("AUTHENTICATE", [ParamAny()]))
if line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text: Optional[str] = None
@ -101,11 +109,11 @@ class SASLContext(ServerContext):
auth_text = f"{username}\0{username}\0{password}"
if not auth_text is None:
auth_b64 = b64encode(auth_text.encode("utf8")
).decode("ascii")
auth_b64 = b64encode(
auth_text.encode("utf8")).decode("ascii")
await self.server.send(build("AUTHENTICATE", [auth_b64]))
line = await self.server.wait_for(Numerics(["903", "904"]))
line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE