diff --git a/ircrobots/__init__.py b/ircrobots/__init__.py index 5b798ed..6b4d641 100644 --- a/ircrobots/__init__.py +++ b/ircrobots/__init__.py @@ -1,5 +1,5 @@ from .bot import Bot from .server import Server -from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, +from .params import (ConnectionParams, ClientTLSCertificate, SASLUserPass, SASLExternal, SASLSCRAM, STSPolicy, ResumePolicy) from .ircv3 import Capability diff --git a/ircrobots/interface.py b/ircrobots/interface.py index f680f5f..29db989 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -5,7 +5,7 @@ from enum import IntEnum from ircstates import Server, Emit from irctokens import Line, Hostmask -from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy +from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy, ClientTLSCertificate class ITCPReader(object): async def read(self, byte_count: int): @@ -24,11 +24,12 @@ class ITCPWriter(object): class ITCPTransport(object): async def connect(self, - hostname: str, - port: int, - tls: bool, - tls_verify: bool=True, - bindhost: Optional[str]=None + hostname: str, + port: int, + tls: bool, + tls_verify: bool=True, + certificate: Optional[ClientTLSCertificate]=None, + bindhost: Optional[str]=None ) -> Tuple[ITCPReader, ITCPWriter]: pass diff --git a/ircrobots/params.py b/ircrobots/params.py index 27117e3..ae02116 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -28,6 +28,12 @@ class ResumePolicy(object): address: str token: str +@dataclass +class ClientTLSCertificate(object): + certfile: str + keyfile: Optional[str] = None + password: Optional[str] = None + @dataclass class ConnectionParams(object): nickname: str @@ -39,9 +45,10 @@ class ConnectionParams(object): realname: Optional[str] = None bindhost: Optional[str] = None - password: Optional[str] = None - tls_verify: bool = True - sasl: Optional[SASLParams] = None + password: Optional[str] = None + tls_verify: bool = True + sasl: Optional[SASLParams] = None + certificate: Optional[ClientTLSCertificate] = None sts: Optional[STSPolicy] = None resume: Optional[ResumePolicy] = None diff --git a/ircrobots/security.py b/ircrobots/security.py index 17d1b78..31a3575 100644 --- a/ircrobots/security.py +++ b/ircrobots/security.py @@ -1,6 +1,13 @@ import ssl -def tls_context(verify: bool=True) -> ssl.SSLContext: +from typing import Optional +from .params import ClientTLSCertificate + +def tls_context( + verify: bool=True, + certificate: Optional[ClientTLSCertificate]=None + ) -> ssl.SSLContext: + context = ssl.SSLContext(ssl.PROTOCOL_TLS) context.options |= ssl.OP_NO_SSLv2 context.options |= ssl.OP_NO_SSLv3 @@ -10,4 +17,11 @@ def tls_context(verify: bool=True) -> ssl.SSLContext: if verify: context.verify_mode = ssl.CERT_REQUIRED + if certificate is not None: + context.load_cert_chain( + certificate.certfile, + certificate.keyfile, + certificate.password + ) + return context diff --git a/ircrobots/server.py b/ircrobots/server.py index d916761..5c10b59 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -124,9 +124,10 @@ class Server(IServer): reader, writer = await transport.connect( params.host, params.port, - tls =params.tls, - tls_verify=params.tls_verify, - bindhost =params.bindhost) + tls =params.tls, + tls_verify =params.tls_verify, + certificate=params.certificate, + bindhost =params.bindhost) self._reader = reader self._writer = writer diff --git a/ircrobots/transport.py b/ircrobots/transport.py index 291409c..8dfaff2 100644 --- a/ircrobots/transport.py +++ b/ircrobots/transport.py @@ -4,6 +4,7 @@ from asyncio import StreamReader, StreamWriter from async_stagger import open_connection from .interface import ITCPTransport, ITCPReader, ITCPWriter +from .params import ClientTLSCertificate from .security import tls_context class TCPReader(ITCPReader): @@ -32,16 +33,17 @@ class TCPWriter(ITCPWriter): class TCPTransport(ITCPTransport): async def connect(self, - hostname: str, - port: int, - tls: bool, - tls_verify: bool=True, - bindhost: Optional[str]=None + hostname: str, + port: int, + tls: bool, + tls_verify: bool=True, + certificate: Optional[ClientTLSCertificate]=None, + bindhost: Optional[str]=None ) -> Tuple[ITCPReader, ITCPWriter]: cur_ssl: Optional[SSLContext] = None if tls: - cur_ssl = tls_context(tls_verify) + cur_ssl = tls_context(tls_verify, certificate) local_addr: Optional[Tuple[str, int]] = None if not bindhost is None: