4
0
mirror of https://github.com/jesopo/ircrobots synced 2024-06-14 20:26:36 +00:00

add numerics.py to translate names, remove Response(errors=)

This commit is contained in:
jesopo 2020-04-02 20:55:01 +01:00
parent 06a4d20fc8
commit a4f5d8045f
4 changed files with 48 additions and 23 deletions

View File

@ -2,7 +2,8 @@ from typing import Iterable, List, Optional
from irctokens import build from irctokens import build
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot,
ParamLiteral)
from .interface import ICapability from .interface import ICapability
class Capability(ICapability): class Capability(ICapability):
@ -57,10 +58,13 @@ CAPS: List[ICapability] = [
class CAPContext(ServerContext): class CAPContext(ServerContext):
async def handshake(self) -> bool: async def handshake(self) -> bool:
# improve this by being able to wait_for Emit objects # improve this by being able to wait_for Emit objects
line = await self.server.wait_for(Response( line = await self.server.wait_for(ResponseOr(
"CAP", Response(
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))], "CAP",
errors=["001"])) [ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
),
Numerics(["RPL_WELCOME"])
))
if line.command == "CAP": if line.command == "CAP":
caps = self.server.collect_caps() caps = self.server.collect_caps()

View File

@ -1,5 +1,6 @@
from typing import List, Optional from typing import List
from irctokens import Line from irctokens import Line
from .numerics import NUMERIC_NAMES
class ResponseParam(object): class ResponseParam(object):
def match(self, arg: str) -> bool: def match(self, arg: str) -> bool:
@ -12,7 +13,7 @@ class BaseResponse(object):
class Numerics(BaseResponse): class Numerics(BaseResponse):
def __init__(self, def __init__(self,
numerics: List[str]): numerics: List[str]):
self._numerics = numerics self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
def match(self, line: Line): def match(self, line: Line):
return line.command in self._numerics return line.command in self._numerics
@ -20,11 +21,9 @@ class Numerics(BaseResponse):
class Response(BaseResponse): class Response(BaseResponse):
def __init__(self, def __init__(self,
command: str, command: str,
params: List[ResponseParam], params: List[ResponseParam]):
errors: Optional[List[str]] = None):
self._command = command self._command = command
self._params = params self._params = params
self._errors = errors or []
def match(self, line: Line) -> bool: def match(self, line: Line) -> bool:
if line.command == self._command: if line.command == self._command:
@ -34,8 +33,6 @@ class Response(BaseResponse):
return False return False
else: else:
return True return True
elif line.command in self._errors:
return True
else: else:
return False return False

16
ircrobots/numerics.py Normal file
View File

@ -0,0 +1,16 @@
NUMERIC_NUMBERS = {}
NUMERIC_NAMES = {}
def _numeric(number: str, name: str):
NUMERIC_NUMBERS[number] = name
NUMERIC_NAMES[name] = number
_numeric("001", "RPL_WELCOME")
_numeric("005", "RPL_ISUPPORT")
_numeric("903", "RPL_SASLSUCCESS")
_numeric("904", "ERR_SASLFAIL")
_numeric("905", "ERR_SASLTOOLONG")
_numeric("906", "ERR_SASLABORTED")
_numeric("907", "ERR_SASLALREADY")
_numeric("908", "RPL_SASLMECHS")

View File

@ -3,7 +3,7 @@ from enum import Enum
from base64 import b64encode from base64 import b64encode
from irctokens import build from irctokens import build
from .matching import Response, Numerics, ParamAny from .matching import Response, ResponseOr, Numerics, ParamAny
from .contexts import ServerContext from .contexts import ServerContext
from .params import SASLParams from .params import SASLParams
@ -25,6 +25,10 @@ class SASLError(Exception):
class SASLUnknownMechanismError(SASLError): class SASLUnknownMechanismError(SASLError):
pass pass
NUMERICS_INITIAL = Numerics(
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
class SASLContext(ServerContext): class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult: async def from_params(self, params: SASLParams) -> SASLResult:
if params.mechanism == "USERPASS": if params.mechanism == "USERPASS":
@ -38,8 +42,10 @@ class SASLContext(ServerContext):
async def external(self) -> SASLResult: async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(Response("AUTHENTICATE", line = await self.server.wait_for(ResponseOr(
[ParamAny()], errors=["904", "907", "908"])) Response("AUTHENTICATE", [ParamAny()]),
NUMERICS_INITIAL
))
if line.command == "907": if line.command == "907":
# we've done SASL already. cleanly abort # we've done SASL already. cleanly abort
@ -52,7 +58,7 @@ class SASLContext(ServerContext):
elif line.command == "AUTHENTICATE" and line.params[0] == "+": elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.server.send(build("AUTHENTICATE", ["+"])) await self.server.send(build("AUTHENTICATE", ["+"]))
line = await self.server.wait_for(Numerics(["903", "904"])) line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903": if line.command == "903":
return SASLResult.SUCCESS return SASLResult.SUCCESS
return SASLResult.FAILURE return SASLResult.FAILURE
@ -80,8 +86,10 @@ class SASLContext(ServerContext):
match = SASL_USERPASS_MECHANISMS[0] match = SASL_USERPASS_MECHANISMS[0]
await self.server.send(build("AUTHENTICATE", [match])) await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(Response("AUTHENTICATE", line = await self.server.wait_for(ResponseOr(
[ParamAny()], errors=["904", "907", "908"])) Response("AUTHENTICATE", [ParamAny()]),
NUMERICS_INITIAL
))
if line.command == "907": if line.command == "907":
# we've done SASL already. cleanly abort # we've done SASL already. cleanly abort
@ -92,8 +100,8 @@ class SASLContext(ServerContext):
match = _common(available) match = _common(available)
await self.server.send(build("AUTHENTICATE", [match])) await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(Response("AUTHENTICATE", line = await self.server.wait_for(
[ParamAny()])) Response("AUTHENTICATE", [ParamAny()]))
if line.command == "AUTHENTICATE" and line.params[0] == "+": if line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text: Optional[str] = None auth_text: Optional[str] = None
@ -101,11 +109,11 @@ class SASLContext(ServerContext):
auth_text = f"{username}\0{username}\0{password}" auth_text = f"{username}\0{username}\0{password}"
if not auth_text is None: if not auth_text is None:
auth_b64 = b64encode(auth_text.encode("utf8") auth_b64 = b64encode(
).decode("ascii") auth_text.encode("utf8")).decode("ascii")
await self.server.send(build("AUTHENTICATE", [auth_b64])) await self.server.send(build("AUTHENTICATE", [auth_b64]))
line = await self.server.wait_for(Numerics(["903", "904"])) line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903": if line.command == "903":
return SASLResult.SUCCESS return SASLResult.SUCCESS
return SASLResult.FAILURE return SASLResult.FAILURE