ircrobots/ircrobots/sasl.py

196 lines
6.7 KiB
Python

from typing import List
from enum import Enum
from base64 import b64decode, b64encode
from irctokens import build
from ircstates.numerics import *
from .matching import Responses, Response, ANY
from .contexts import ServerContext
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext, SCRAMAlgorithm
# SCRAM mechanism names this client can offer, in the order they are listed
# to (and tried against) the server.
SASL_SCRAM_MECHANISMS = ["SCRAM-SHA-512", "SCRAM-SHA-256", "SCRAM-SHA-1"]

# Every username/password based mechanism: all SCRAM variants, then PLAIN.
# Built as a new list so the two constants never share storage.
SASL_USERPASS_MECHANISMS = [*SASL_SCRAM_MECHANISMS, "PLAIN"]
class SASLResult(Enum):
    """Outcome of a SASL authentication attempt."""

    # not returned by any code path in this module
    NONE = 0
    # the server confirmed the authentication (numeric 903)
    SUCCESS = 1
    # the exchange ended without a success confirmation
    FAILURE = 2
    # the server reported we are already authenticated (numeric 907)
    ALREADY = 3
class SASLError(Exception):
    """Base class for the SASL-specific errors raised by this module."""
class SASLUnknownMechanismError(SASLError):
    """Raised when no usable SASL mechanism can be selected.

    This covers parameters naming a mechanism this module does not handle,
    and a server that offers none of the mechanisms we asked for.
    """
# Maximum length of the payload carried by a single AUTHENTICATE line; longer
# payloads are split into chunks of this size before sending.
AUTH_BYTE_MAX = 400
# Matches an AUTHENTICATE line from the server, whatever its parameter is.
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
# Matches the failure numeric on its own; awaited after a mechanism list
# (908) has been received.
NUMERICS_FAIL = Response(ERR_SASLFAIL)
# Numerics that may come back right after we name a mechanism, instead of
# the server's AUTHENTICATE continuation.
NUMERICS_INITIAL = Responses([
ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS, ERR_SASLABORTED
])
# Numerics that conclude an exchange once our final payload has been sent.
NUMERICS_LAST = Responses([RPL_SASLSUCCESS, ERR_SASLFAIL])
def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii")
def _b64eb(s: bytes) -> str:
# encode-from-bytes
return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes:
# decode-to-bytes
return b64decode(s)
class SASLContext(ServerContext):
    """Runs SASL authentication exchanges over ``self.server``.

    Each public coroutine sends ``AUTHENTICATE`` lines with
    ``self.server.send``, reads the replies with ``self.server.wait_for`` and
    reports the outcome as a ``SASLResult``.
    """

    async def from_params(self, params: SASLParams) -> SASLResult:
        """Authenticate with whichever mechanism family *params* describes.

        Returns:
            The result of ``userpass``, ``scram`` or ``external``.

        Raises:
            SASLUnknownMechanismError: *params* is none of ``SASLUserPass``,
                ``SASLSCRAM`` or ``SASLExternal``.
        """
        if isinstance(params, SASLUserPass):
            return await self.userpass(params.username, params.password)
        elif isinstance(params, SASLSCRAM):
            return await self.scram(params.username, params.password)
        elif isinstance(params, SASLExternal):
            return await self.external()
        else:
            raise SASLUnknownMechanismError(
                "SASLParams given with unknown mechanism "
                f"{params.mechanism!r}")

    async def external(self) -> SASLResult:
        """Authenticate with the EXTERNAL mechanism.

        Returns:
            ``ALREADY`` on 907, ``SUCCESS`` when the exchange ends with 903,
            ``FAILURE`` for every other outcome.

        Raises:
            SASLUnknownMechanismError: the server answered with 908, i.e. it
                listed its mechanisms instead of accepting EXTERNAL.
        """
        await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
        line = await self.server.wait_for({
            AUTHENTICATE_ANY,
            NUMERICS_INITIAL
        })

        if line.command == "907":
            # we've done SASL already. cleanly abort
            return SASLResult.ALREADY
        elif line.command == "908":
            available = line.params[1].split(",")
            raise SASLUnknownMechanismError(
                "Server does not support SASL EXTERNAL "
                f"(it supports {available})")
        elif line.command == "AUTHENTICATE" and line.params[0] == "+":
            # EXTERNAL carries no payload of its own: answer the server's
            # continuation with an empty ("+") response
            await self.server.send(build("AUTHENTICATE", ["+"]))

            line = await self.server.wait_for(NUMERICS_LAST)
            if line.command == "903":
                return SASLResult.SUCCESS
        return SASLResult.FAILURE

    async def plain(self, username: str, password: str) -> SASLResult:
        """Authenticate with PLAIN only."""
        return await self.userpass(username, password, ["PLAIN"])

    async def scram(self, username: str, password: str) -> SASLResult:
        """Authenticate with the SCRAM mechanisms only."""
        return await self.userpass(username, password, SASL_SCRAM_MECHANISMS)

    async def userpass(self,
            username: str,
            password: str,
            mechanisms: List[str]=SASL_USERPASS_MECHANISMS
            ) -> SASLResult:
        """Authenticate with a username and password.

        The candidate mechanisms are tried in order; when the server rejects
        one with 904 the next candidate is attempted.

        Args:
            username: account name to authenticate as.
            password: password for that account.
            mechanisms: mechanism names we are willing to use, most preferred
                first. The list is never modified.

        Returns:
            ``SUCCESS`` on 903, ``ALREADY`` on 907, ``FAILURE`` once the
            candidates are exhausted or the server sends something
            unexpected.

        Raises:
            SASLUnknownMechanismError: the server advertises mechanisms and
                none of them is in *mechanisms*.
        """
        def _common(server_mechs) -> List[str]:
            # our mechanisms that the server also has, in our preference order
            mechs = [mech for mech in mechanisms if mech in server_mechs]
            if mechs:
                return mechs
            raise SASLUnknownMechanismError(
                "No matching SASL mechanisms. "
                f"(we want: {mechanisms} "
                f"server has: {server_mechs})")

        if self.server.available_caps["sasl"]:
            # CAP v3.2 tells us what mechs it supports
            available = self.server.available_caps["sasl"].split(",")
            match = _common(available)
        else:
            # CAP v3.1 does not. try our first pick and rely on 908 to inform
            # us of what mechanisms are supported.
            # Copy the list: rejected mechanisms are popped from `match`
            # below, and `mechanisms` may be a shared module-level list.
            match = list(mechanisms)

        while match:
            await self.server.send(build("AUTHENTICATE", [match[0]]))
            line = await self.server.wait_for({
                AUTHENTICATE_ANY,
                NUMERICS_INITIAL
            })

            if line.command == "907":
                # we've done SASL already. cleanly abort
                return SASLResult.ALREADY
            elif line.command == "908":
                # prior to CAP v3.2 - ERR telling us which mechs are supported
                available = line.params[1].split(",")
                match = _common(available)
                # consume the failure numeric awaited after the 908 before
                # retrying with the narrowed list
                await self.server.wait_for(NUMERICS_FAIL)
            elif line.command == "AUTHENTICATE" and line.params[0] == "+":
                auth_text = ""

                if match[0] == "PLAIN":
                    # authorization id, authentication id, password
                    auth_text = f"{username}\0{username}\0{password}"
                elif match[0].startswith("SCRAM-SHA-"):
                    auth_text = await self._scram(
                        match[0], username, password)

                # "+" is already a wire-level empty response; anything else
                # is raw text that still needs base64 encoding
                if auth_text != "+":
                    auth_text = _b64e(auth_text)

                if auth_text:
                    await self._send_auth_text(auth_text)

                line = await self.server.wait_for(NUMERICS_LAST)
                if line.command == "903":
                    return SASLResult.SUCCESS
                elif line.command == "904":
                    # rejected: move on to the next candidate mechanism
                    match.pop(0)
            else:
                break

        return SASLResult.FAILURE

    async def _scram(self, algo_str: str,
            username: str,
            password: str) -> str:
        """Run the SCRAM message exchange for mechanism *algo_str*.

        Returns:
            ``"+"`` once the server's final message has been processed, or
            ``""`` when there was no client-final message to send.

        Raises:
            ValueError: *algo_str* does not map to a ``SCRAMAlgorithm``.
        """
        # "SCRAM-SHA-256" -> "SHA256"
        algo_str_prep = algo_str.replace("SCRAM-", "", 1
            ).replace("-", "").upper()
        try:
            algo = SCRAMAlgorithm(algo_str_prep)
        except ValueError as exc:
            raise ValueError(
                "Unknown SCRAM algorithm '%s'" % algo_str_prep) from exc

        scram = SCRAMContext(algo, username, password)

        client_first = _b64eb(scram.client_first())
        await self._send_auth_text(client_first)
        line = await self.server.wait_for(AUTHENTICATE_ANY)

        server_first = _b64db(line.params[0])
        client_final = _b64eb(scram.server_first(server_first))
        if client_final != "":
            await self._send_auth_text(client_final)
            line = await self.server.wait_for(AUTHENTICATE_ANY)

            server_final = _b64db(line.params[0])
            verified = scram.server_final(server_final)
            #TODO PANIC if verified is false!
            # NOTE(review): `verified` is presumably the outcome of checking
            # the server's signature; it is currently ignored - confirm the
            # contract of SCRAMContext.server_final before acting on it.
            return "+"
        else:
            return ""

    async def _send_auth_text(self, text: str) -> None:
        """Send *text* as one or more ``AUTHENTICATE`` lines.

        The text is split into chunks of at most ``AUTH_BYTE_MAX``
        characters. When the last chunk is exactly that long, or when *text*
        is empty, a trailing ``+`` line is sent to mark the end of the
        payload.
        """
        n = AUTH_BYTE_MAX
        chunks = [text[i:i+n] for i in range(0, len(text), n)]
        # an empty payload yields no chunks at all; it goes out as a lone "+"
        if not chunks or len(chunks[-1]) == n:
            chunks.append("+")

        for chunk in chunks:
            await self.server.send(build("AUTHENTICATE", [chunk]))