# ircrobots/ircrobots/sasl.py
# (listing metadata: 175 lines, 6.2 KiB, Python)

from typing import List
from enum import Enum
from base64 import b64decode, b64encode
from irctokens import build
from .matching import Response, ResponseOr, Numerics, ParamAny
from .contexts import ServerContext
from .params import SASLParams
from .scram import SCRAMContext
# SCRAM mechanisms we can perform, listed in order of preference.
SASL_SCRAM_MECHANISMS = ["SCRAM-SHA-512", "SCRAM-SHA-256", "SCRAM-SHA-1"]

# Every username/password mechanism we support: the SCRAM family first,
# PLAIN as the last resort.
SASL_USERPASS_MECHANISMS = [*SASL_SCRAM_MECHANISMS, "PLAIN"]
class SASLResult(Enum):
    """Outcome of a SASL authentication attempt."""

    # Not produced by SASLContext in this module; presumably "no attempt".
    NONE = 0
    # The server replied RPL_SASLSUCCESS (903).
    SUCCESS = 1
    # Authentication did not succeed.
    FAILURE = 2
    # The server replied ERR_SASLALREADY (907): SASL was already done.
    ALREADY = 3
class SASLError(Exception):
    """Base class for errors raised by the SASL helpers in this module."""
class SASLUnknownMechanismError(SASLError):
    """Raised when a requested SASL mechanism cannot be used.

    Either the mechanism name itself is unknown, or the server does not
    offer any mechanism we are willing to use.
    """
# Any AUTHENTICATE line from the server, whatever its payload.
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ParamAny()])

# The server rejected the authentication attempt.
NUMERICS_FAIL = Numerics(["ERR_SASLFAIL"])

# Numerics that may answer our initial "AUTHENTICATE <mechanism>".
NUMERICS_INITIAL = Numerics([
    "ERR_SASLFAIL",
    "ERR_SASLALREADY",
    "RPL_SASLMECHS",
])

# Terminal numerics once credentials have been sent.
NUMERICS_LAST = Numerics([
    "RPL_SASLSUCCESS",
    "ERR_SASLFAIL",
])
def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii")
def _b64eb(s: bytes) -> str:
# encode-from-bytes
return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes:
# decode-to-bytes
return b64decode(s)
class SASLContext(ServerContext):
    """Drive SASL authentication over ``self.server``.

    Every method sends ``AUTHENTICATE`` lines with ``self.server.send`` and
    reads the replies with ``self.server.wait_for``.
    """

    async def from_params(self, params: SASLParams) -> SASLResult:
        """Authenticate using the mechanism family named by *params*.

        ``"USERPASS"`` tries every username/password mechanism, ``"SCRAM"``
        only the SCRAM variants, and ``"EXTERNAL"`` uses SASL EXTERNAL.

        Raises:
            SASLUnknownMechanismError: ``params.mechanism`` is none of the
                above, or the server shares no usable mechanism with us.
        """
        if params.mechanism == "USERPASS":
            return await self.userpass(params.username, params.password)
        elif params.mechanism == "SCRAM":
            return await self.scram(params.username, params.password)
        elif params.mechanism == "EXTERNAL":
            return await self.external()
        else:
            raise SASLUnknownMechanismError(
                "SASLParams given with unknown mechanism "
                f"{params.mechanism!r}")

    async def external(self) -> SASLResult:
        """Authenticate with SASL EXTERNAL.

        Returns:
            SUCCESS on RPL_SASLSUCCESS (903), ALREADY on 907, otherwise
            FAILURE.

        Raises:
            SASLUnknownMechanismError: the server replied with its mechanism
                list (908), i.e. it does not offer EXTERNAL.
        """
        await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
        line = await self.server.wait_for(ResponseOr(
            AUTHENTICATE_ANY,
            NUMERICS_INITIAL
        ))

        if line.command == "907":
            # we've done SASL already. cleanly abort
            return SASLResult.ALREADY
        elif line.command == "908":
            available = line.params[1].split(",")
            raise SASLUnknownMechanismError(
                "Server does not support SASL EXTERNAL "
                f"(it supports {available})")
        elif line.command == "AUTHENTICATE" and line.params[0] == "+":
            # EXTERNAL sends no client data: "+" is an empty response
            await self.server.send(build("AUTHENTICATE", ["+"]))

            line = await self.server.wait_for(NUMERICS_LAST)
            if line.command == "903":
                return SASLResult.SUCCESS
        return SASLResult.FAILURE

    async def plain(self, username: str, password: str) -> SASLResult:
        """Authenticate with SASL PLAIN only."""
        return await self.userpass(username, password, ["PLAIN"])

    async def scram(self, username: str, password: str) -> SASLResult:
        """Authenticate with the SCRAM-SHA-* mechanisms only."""
        return await self.userpass(username, password, SASL_SCRAM_MECHANISMS)

    async def userpass(self,
            username: str,
            password: str,
            mechanisms: List[str]=SASL_USERPASS_MECHANISMS
            ) -> SASLResult:
        """Authenticate with a username/password SASL mechanism.

        *mechanisms* is tried in order of preference. When one is rejected
        with ERR_SASLFAIL (904), the next one is tried. *mechanisms* itself
        is never modified. This matters because the default is the shared
        module-level list.

        Returns:
            SUCCESS on RPL_SASLSUCCESS (903), ALREADY on 907. FAILURE once
            every candidate has been rejected, or when a SCRAM exchange was
            aborted because the server's final message did not verify.

        Raises:
            SASLUnknownMechanismError: the server offers none of
                *mechanisms*.
        """
        def _common(server_mechs: List[str]) -> List[str]:
            # our mechanisms that the server also offers, in OUR order
            mechs = [mech for mech in mechanisms if mech in server_mechs]
            if mechs:
                return mechs
            raise SASLUnknownMechanismError(
                "No matching SASL mechanisms. "
                f"(we want: {mechanisms} "
                f"server has: {server_mechs})")

        if self.server.available_caps["sasl"]:
            # CAP v3.2 tells us what mechs it supports
            available = self.server.available_caps["sasl"].split(",")
            match = _common(available)
        else:
            # CAP v3.1 does not. try ours in order and wait for 908 to tell
            # us what the server supports. copy, so the pop()s below never
            # mutate the caller's list (or the module-level default)
            match = list(mechanisms)

        while match:
            mech = match[0]
            await self.server.send(build("AUTHENTICATE", [mech]))
            line = await self.server.wait_for(ResponseOr(
                AUTHENTICATE_ANY,
                NUMERICS_INITIAL
            ))

            if line.command == "907":
                # we've done SASL already. cleanly abort
                return SASLResult.ALREADY
            elif line.command == "908":
                # prior to CAP v3.2 - ERR telling us which mechs are supported
                available = line.params[1].split(",")
                match = _common(available)
                await self.server.wait_for(NUMERICS_FAIL)
            elif line.command == "904":
                # mechanism rejected without a 908 list: drop it and try
                # the next one instead of re-sending it forever
                match.pop(0)
            elif line.command == "AUTHENTICATE" and line.params[0] == "+":
                auth_text = ""
                if mech == "PLAIN":
                    # authzid \0 authcid \0 password
                    auth_text = f"{username}\0{username}\0{password}"
                elif mech.startswith("SCRAM-SHA-"):
                    auth_text = await self._scram(mech, username, password)
                    if auth_text == "*":
                        # _scram already sent "AUTHENTICATE *" to abort
                        return SASLResult.FAILURE

                # "+" is the protocol's empty response and is sent verbatim
                if auth_text != "+":
                    auth_text = _b64e(auth_text)
                if auth_text:
                    await self.server.send(build("AUTHENTICATE", [auth_text]))

                line = await self.server.wait_for(NUMERICS_LAST)
                if line.command == "903":
                    return SASLResult.SUCCESS
                elif line.command == "904":
                    match.pop(0)
        return SASLResult.FAILURE

    async def _scram(self, algo: str, username: str, password: str) -> str:
        """Run the SCRAM exchange for *algo* (e.g. ``"SCRAM-SHA-256"``).

        Sends client-first and client-final, and checks the server-final
        message with ``SCRAMContext``.

        Returns:
            ``"+"``: the exchange finished and the server's final message
            verified. The caller sends it as the closing empty response.
            ``"*"``: the server's final message failed verification. An
            abort (``AUTHENTICATE *``) has already been sent.
            ``""``: ``scram.server_first()`` produced no client-final
            message, so nothing further was sent.
        """
        algo = algo.replace("SCRAM-", "", 1)
        scram = SCRAMContext(algo, username, password)

        client_first = _b64eb(scram.client_first())
        await self.server.send(build("AUTHENTICATE", [client_first]))
        line = await self.server.wait_for(AUTHENTICATE_ANY)

        server_first = _b64db(line.params[0])
        client_final = _b64eb(scram.server_first(server_first))
        if client_final == "":
            return ""

        await self.server.send(build("AUTHENTICATE", [client_final]))
        line = await self.server.wait_for(AUTHENTICATE_ANY)

        server_final = _b64db(line.params[0])
        if not scram.server_final(server_final):
            # the server's final message did not verify: don't complete the
            # exchange, abort it instead
            await self.server.send(build("AUTHENTICATE", ["*"]))
            return "*"
        return "+"