also catch sts outside of connection reg, throw up Disconnect to be sure

This commit is contained in:
jesopo 2020-04-20 16:21:29 +01:00
parent dfabe99916
commit fedef3ba3d
1 changed files with 11 additions and 5 deletions

View File

@ -2,6 +2,7 @@ from time import time
from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass
from irctokens import build
from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
@ -85,6 +86,8 @@ async def sts_transmute(params: ConnectionParams):
class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]):
await self._sts(tokens)
caps = list(self.server.desired_caps)+CAPS
if (not self.server.params.sasl is None and
@ -112,9 +115,13 @@ class CAPContext(ServerContext):
await self.server.sasl_auth(self.server.params.sasl)
async def handshake(self):
cap_sts = CAP_STS.available(self.server.available_caps)
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))
async def _sts(self, tokens: Dict[str, str]):
cap_sts = CAP_STS.available(tokens)
if not cap_sts is None:
sts_dict = _cap_dict(self.server.available_caps[cap_sts])
sts_dict = _cap_dict(tokens[cap_sts])
params = self.server.params
if not params.tls:
if "port" in sts_dict:
@ -123,7 +130,8 @@ class CAPContext(ServerContext):
await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params)
return
raise ServerDisconnectedException()
elif "duration" in sts_dict:
policy = STSPolicy(
int(time()),
@ -132,5 +140,3 @@ class CAPContext(ServerContext):
"preload" in sts_dict)
self.server.sts_policy(policy)
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))