diff --git a/ircrobots/interface.py b/ircrobots/interface.py index dc8a313..dd52990 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -5,7 +5,7 @@ from enum import IntEnum from ircstates import Server, Emit from irctokens import Line, Hostmask -from .params import ConnectionParams, SASLParams, STSPolicy +from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy class ITCPReader(object): async def read(self, byte_count: int): @@ -13,8 +13,14 @@ class ITCPReader(object): class ITCPWriter(object): def write(self, data: bytes): pass + + def get_peer(self) -> Tuple[str, int]: + pass + async def drain(self): pass + async def close(self): + pass class ITCPTransport(object): async def connect(self, @@ -84,6 +90,9 @@ class IServer(Server): def set_throttle(self, rate: int, time: float): pass + def server_address(self) -> Tuple[str, int]: + pass + async def connect(self, transport: ITCPTransport, params: ConnectionParams): @@ -97,6 +106,8 @@ class IServer(Server): pass async def sts_policy(self, sts: STSPolicy): pass + async def resume_policy(self, resume: ResumePolicy): + pass async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]: pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index 3ccef8a..4477714 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -7,7 +7,7 @@ from ircstates.server import ServerDisconnectedException from .contexts import ServerContext from .matching import Response, ResponseOr, ANY from .interface import ICapability -from .params import ConnectionParams, STSPolicy +from .params import ConnectionParams, STSPolicy, ResumePolicy class Capability(ICapability): def __init__(self, @@ -40,10 +40,11 @@ class Capability(ICapability): alias=self.alias, depends_on=self.depends_on[:]) -CAP_SASL = Capability("sasl") -CAP_ECHO = Capability("echo-message") -CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") -CAP_STS = Capability("sts", "draft/sts") +CAP_SASL = Capability("sasl") +CAP_ECHO = Capability("echo-message") +CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") +CAP_STS = Capability("sts", "draft/sts") +CAP_RESUME = Capability(None, "draft/resume-0.5", alias="resume") LABEL_TAG = { "draft/labeled-response-0.2": "draft/label", @@ -65,7 +66,8 @@ CAPS: List[ICapability] = [ Capability("batch"), Capability(None, "draft/rename", alias="rename"), - Capability("setname", "draft/setname") + Capability("setname", "draft/setname"), + CAP_RESUME ] def _cap_dict(s: str) -> Dict[str, str]: @@ -82,6 +84,9 @@ async def sts_transmute(params: ConnectionParams): if since <= params.sts.duration: params.port = params.sts.port params.tls = True +async def resume_transmute(params: ConnectionParams): + if params.resume is not None: + params.host = params.resume.address class CAPContext(ServerContext): async def on_ls(self, tokens: Dict[str, str]): @@ -109,10 +114,22 @@ class CAPContext(ServerContext): for cap in current_caps: if cap in cap_names: cap_names.remove(cap) + if CAP_RESUME.available(current_caps): + await self.resume_token() + if (self.server.cap_agreed(CAP_SASL) and not self.server.params.sasl is None): await self.server.sasl_auth(self.server.params.sasl) + async def resume_token(self): + line = await self.server.wait_for(Response("RESUME", ["TOKEN", ANY])) + token = line.params[1] + address, port = self.server.server_address() + + resume_policy = ResumePolicy(address, token) + self.server.params.resume = resume_policy + await self.server.resume_policy(resume_policy) + async def handshake(self): await self.on_ls(self.server.available_caps) await self.server.send(build("CAP", ["END"])) diff --git a/ircrobots/params.py b/ircrobots/params.py index 7d14688..f0fc1e5 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -27,6 +27,11 @@ class STSPolicy(object): duration: int preload: bool +@dataclass +class ResumePolicy(object): + address: str + token: str + @dataclass class ConnectionParams(object): nickname: str @@ -42,4 +47,5 @@ class ConnectionParams(object): tls_verify: bool = True sasl: Optional[SASLParams] = None - sts: Optional[STSPolicy] = None + sts: Optional[STSPolicy] = None + resume: Optional[ResumePolicy] = None diff --git a/ircrobots/server.py b/ircrobots/server.py index a10957a..5097741 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -10,13 +10,13 @@ from ircstates.server import ServerDisconnectedException from irctokens import build, Line, tokenise from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, - CAP_LABEL, LABEL_TAG) + CAP_LABEL, LABEL_TAG, resume_transmute) from .sasl import SASLContext, SASLResult from .join_info import WHOContext from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname from .asyncs import MaybeAwait, WaitFor from .struct import Whois -from .params import ConnectionParams, SASLParams, STSPolicy +from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, IMatchResponse) from .interface import ITCPTransport, ITCPReader, ITCPWriter @@ -84,10 +84,14 @@ class Server(IServer): self.throttle.rate_limit = rate self.throttle.period = time + def server_address(self) -> Tuple[str, int]: + return self._writer.get_peer() + async def connect(self, transport: ITCPTransport, params: ConnectionParams): await sts_transmute(params) + await resume_transmute(params) reader, writer = await transport.connect( params.host, @@ -126,6 +130,8 @@ class Server(IServer): pass async def sts_policy(self, sts: STSPolicy): pass + async def resume_policy(self, resume: ResumePolicy): + pass # /to be overriden async def _on_read_emit(self, line: Line, emit: Emit): diff --git a/ircrobots/transport.py b/ircrobots/transport.py index 2d276ff..291409c 100644 --- a/ircrobots/transport.py +++ b/ircrobots/transport.py @@ -16,6 +16,10 @@ class TCPWriter(ITCPWriter): def __init__(self, writer: StreamWriter): self._writer = writer + def get_peer(self) -> Tuple[str, int]: + address, port, *_ = self._writer.transport.get_extra_info("peername") + return (address, port) + def write(self, data: bytes): self._writer.write(data)