diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index a31efe6..791a8c6 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -2,6 +2,7 @@ from time import time from typing import Dict, Iterable, List, Optional, Tuple from dataclasses import dataclass from irctokens import build +from ircstates.server import ServerDisconnectedException from .contexts import ServerContext from .matching import Response, ResponseOr, ParamAny, ParamLiteral @@ -85,6 +86,8 @@ async def sts_transmute(params: ConnectionParams): class CAPContext(ServerContext): async def on_ls(self, tokens: Dict[str, str]): + await self._sts(tokens) + caps = list(self.server.desired_caps)+CAPS if (not self.server.params.sasl is None and @@ -112,9 +115,13 @@ class CAPContext(ServerContext): await self.server.sasl_auth(self.server.params.sasl) async def handshake(self): - cap_sts = CAP_STS.available(self.server.available_caps) + await self.on_ls(self.server.available_caps) + await self.server.send(build("CAP", ["END"])) + + async def _sts(self, tokens: Dict[str, str]): + cap_sts = CAP_STS.available(tokens) if not cap_sts is None: - sts_dict = _cap_dict(self.server.available_caps[cap_sts]) + sts_dict = _cap_dict(tokens[cap_sts]) params = self.server.params if not params.tls: if "port" in sts_dict: @@ -123,7 +130,8 @@ class CAPContext(ServerContext): await self.server.bot.disconnect(self.server) await self.server.bot.add_server(self.server.name, params) - return + raise ServerDisconnectedException() + elif "duration" in sts_dict: policy = STSPolicy( int(time()), @@ -132,5 +140,3 @@ class CAPContext(ServerContext): "preload" in sts_dict) self.server.sts_policy(policy) - await self.on_ls(self.server.available_caps) - await self.server.send(build("CAP", ["END"]))