mirror of
https://github.com/jesopo/ircrobots
synced 2024-06-14 20:26:36 +00:00
add numerics.py to translate names, remove Response(errors=)
This commit is contained in:
parent
06a4d20fc8
commit
a4f5d8045f
|
@ -2,7 +2,8 @@ from typing import Iterable, List, Optional
|
||||||
from irctokens import build
|
from irctokens import build
|
||||||
|
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
|
from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot,
|
||||||
|
ParamLiteral)
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
|
|
||||||
class Capability(ICapability):
|
class Capability(ICapability):
|
||||||
|
@ -57,10 +58,13 @@ CAPS: List[ICapability] = [
|
||||||
class CAPContext(ServerContext):
|
class CAPContext(ServerContext):
|
||||||
async def handshake(self) -> bool:
|
async def handshake(self) -> bool:
|
||||||
# improve this by being able to wait_for Emit objects
|
# improve this by being able to wait_for Emit objects
|
||||||
line = await self.server.wait_for(Response(
|
line = await self.server.wait_for(ResponseOr(
|
||||||
"CAP",
|
Response(
|
||||||
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
|
"CAP",
|
||||||
errors=["001"]))
|
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
|
||||||
|
),
|
||||||
|
Numerics(["RPL_WELCOME"])
|
||||||
|
))
|
||||||
|
|
||||||
if line.command == "CAP":
|
if line.command == "CAP":
|
||||||
caps = self.server.collect_caps()
|
caps = self.server.collect_caps()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
from irctokens import Line
|
from irctokens import Line
|
||||||
|
from .numerics import NUMERIC_NAMES
|
||||||
|
|
||||||
class ResponseParam(object):
|
class ResponseParam(object):
|
||||||
def match(self, arg: str) -> bool:
|
def match(self, arg: str) -> bool:
|
||||||
|
@ -12,7 +13,7 @@ class BaseResponse(object):
|
||||||
class Numerics(BaseResponse):
|
class Numerics(BaseResponse):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
numerics: List[str]):
|
numerics: List[str]):
|
||||||
self._numerics = numerics
|
self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
|
||||||
|
|
||||||
def match(self, line: Line):
|
def match(self, line: Line):
|
||||||
return line.command in self._numerics
|
return line.command in self._numerics
|
||||||
|
@ -20,11 +21,9 @@ class Numerics(BaseResponse):
|
||||||
class Response(BaseResponse):
|
class Response(BaseResponse):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
command: str,
|
command: str,
|
||||||
params: List[ResponseParam],
|
params: List[ResponseParam]):
|
||||||
errors: Optional[List[str]] = None):
|
|
||||||
self._command = command
|
self._command = command
|
||||||
self._params = params
|
self._params = params
|
||||||
self._errors = errors or []
|
|
||||||
|
|
||||||
def match(self, line: Line) -> bool:
|
def match(self, line: Line) -> bool:
|
||||||
if line.command == self._command:
|
if line.command == self._command:
|
||||||
|
@ -34,8 +33,6 @@ class Response(BaseResponse):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
elif line.command in self._errors:
|
|
||||||
return True
|
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
16
ircrobots/numerics.py
Normal file
16
ircrobots/numerics.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
NUMERIC_NUMBERS = {}
|
||||||
|
NUMERIC_NAMES = {}
|
||||||
|
|
||||||
|
def _numeric(number: str, name: str):
|
||||||
|
NUMERIC_NUMBERS[number] = name
|
||||||
|
NUMERIC_NAMES[name] = number
|
||||||
|
|
||||||
|
_numeric("001", "RPL_WELCOME")
|
||||||
|
_numeric("005", "RPL_ISUPPORT")
|
||||||
|
|
||||||
|
_numeric("903", "RPL_SASLSUCCESS")
|
||||||
|
_numeric("904", "ERR_SASLFAIL")
|
||||||
|
_numeric("905", "ERR_SASLTOOLONG")
|
||||||
|
_numeric("906", "ERR_SASLABORTED")
|
||||||
|
_numeric("907", "ERR_SASLALREADY")
|
||||||
|
_numeric("908", "RPL_SASLMECHS")
|
|
@ -3,7 +3,7 @@ from enum import Enum
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from irctokens import build
|
from irctokens import build
|
||||||
|
|
||||||
from .matching import Response, Numerics, ParamAny
|
from .matching import Response, ResponseOr, Numerics, ParamAny
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .params import SASLParams
|
from .params import SASLParams
|
||||||
|
|
||||||
|
@ -25,6 +25,10 @@ class SASLError(Exception):
|
||||||
class SASLUnknownMechanismError(SASLError):
|
class SASLUnknownMechanismError(SASLError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
NUMERICS_INITIAL = Numerics(
|
||||||
|
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
|
||||||
|
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
|
||||||
|
|
||||||
class SASLContext(ServerContext):
|
class SASLContext(ServerContext):
|
||||||
async def from_params(self, params: SASLParams) -> SASLResult:
|
async def from_params(self, params: SASLParams) -> SASLResult:
|
||||||
if params.mechanism == "USERPASS":
|
if params.mechanism == "USERPASS":
|
||||||
|
@ -38,8 +42,10 @@ class SASLContext(ServerContext):
|
||||||
|
|
||||||
async def external(self) -> SASLResult:
|
async def external(self) -> SASLResult:
|
||||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
line = await self.server.wait_for(ResponseOr(
|
||||||
[ParamAny()], errors=["904", "907", "908"]))
|
Response("AUTHENTICATE", [ParamAny()]),
|
||||||
|
NUMERICS_INITIAL
|
||||||
|
))
|
||||||
|
|
||||||
if line.command == "907":
|
if line.command == "907":
|
||||||
# we've done SASL already. cleanly abort
|
# we've done SASL already. cleanly abort
|
||||||
|
@ -52,7 +58,7 @@ class SASLContext(ServerContext):
|
||||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||||
await self.server.send(build("AUTHENTICATE", ["+"]))
|
await self.server.send(build("AUTHENTICATE", ["+"]))
|
||||||
|
|
||||||
line = await self.server.wait_for(Numerics(["903", "904"]))
|
line = await self.server.wait_for(NUMERICS_LAST)
|
||||||
if line.command == "903":
|
if line.command == "903":
|
||||||
return SASLResult.SUCCESS
|
return SASLResult.SUCCESS
|
||||||
return SASLResult.FAILURE
|
return SASLResult.FAILURE
|
||||||
|
@ -80,8 +86,10 @@ class SASLContext(ServerContext):
|
||||||
match = SASL_USERPASS_MECHANISMS[0]
|
match = SASL_USERPASS_MECHANISMS[0]
|
||||||
|
|
||||||
await self.server.send(build("AUTHENTICATE", [match]))
|
await self.server.send(build("AUTHENTICATE", [match]))
|
||||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
line = await self.server.wait_for(ResponseOr(
|
||||||
[ParamAny()], errors=["904", "907", "908"]))
|
Response("AUTHENTICATE", [ParamAny()]),
|
||||||
|
NUMERICS_INITIAL
|
||||||
|
))
|
||||||
|
|
||||||
if line.command == "907":
|
if line.command == "907":
|
||||||
# we've done SASL already. cleanly abort
|
# we've done SASL already. cleanly abort
|
||||||
|
@ -92,8 +100,8 @@ class SASLContext(ServerContext):
|
||||||
match = _common(available)
|
match = _common(available)
|
||||||
|
|
||||||
await self.server.send(build("AUTHENTICATE", [match]))
|
await self.server.send(build("AUTHENTICATE", [match]))
|
||||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
line = await self.server.wait_for(
|
||||||
[ParamAny()]))
|
Response("AUTHENTICATE", [ParamAny()]))
|
||||||
|
|
||||||
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||||
auth_text: Optional[str] = None
|
auth_text: Optional[str] = None
|
||||||
|
@ -101,11 +109,11 @@ class SASLContext(ServerContext):
|
||||||
auth_text = f"{username}\0{username}\0{password}"
|
auth_text = f"{username}\0{username}\0{password}"
|
||||||
|
|
||||||
if not auth_text is None:
|
if not auth_text is None:
|
||||||
auth_b64 = b64encode(auth_text.encode("utf8")
|
auth_b64 = b64encode(
|
||||||
).decode("ascii")
|
auth_text.encode("utf8")).decode("ascii")
|
||||||
await self.server.send(build("AUTHENTICATE", [auth_b64]))
|
await self.server.send(build("AUTHENTICATE", [auth_b64]))
|
||||||
|
|
||||||
line = await self.server.wait_for(Numerics(["903", "904"]))
|
line = await self.server.wait_for(NUMERICS_LAST)
|
||||||
if line.command == "903":
|
if line.command == "903":
|
||||||
return SASLResult.SUCCESS
|
return SASLResult.SUCCESS
|
||||||
return SASLResult.FAILURE
|
return SASLResult.FAILURE
|
||||||
|
|
Loading…
Reference in New Issue
Block a user