Compare commits

...

25 Commits

Author SHA1 Message Date
jesopo 7c9a144124 v0.6.6 release 2023-08-17 22:48:21 +00:00
jesopo e3c91a50e1 update ircstates 2023-08-17 22:47:57 +00:00
jesopo f2ba48a582 v0.6.5 release 2023-07-06 00:57:08 +00:00
jesopo cf2e69a9e2 asyncio.wait(..) now requires Tasks 2023-07-06 00:56:45 +00:00
jesopo a1a459c13e v0.6.4 release 2023-07-06 00:44:25 +00:00
jesopo 81fa77cf29 missed some TLS_ uses 2023-07-06 00:44:13 +00:00
jesopo 422a9a93c1 v0.6.3 release 2023-07-06 00:35:47 +00:00
jesopo b04a0e0136 python no longer likes having mutables in non-default_factory 2023-07-06 00:35:13 +00:00
jesopo 7bb4c3d069 v0.6.2 release 2023-02-06 19:43:14 +00:00
jesopo 9a2f2156fe support specifying tls client keypair 2023-02-06 19:42:27 +00:00
alicetries 0435404ec3 Small tweak to how repr() of Formatless() displays 2022-03-28 23:43:27 +01:00
jesopo 63025af311 v0.6.1 release 2022-02-19 13:51:05 +00:00
jesopo 20c4f8f98c upgrade async-timeout to v4.0.2 2022-02-19 13:49:21 +00:00
jesopo 0ce3b9b0b0 v0.6.0 release 2022-01-24 10:01:51 +00:00
jesopo 5b347f95c9 combine params.tls and .tls_verify, support pinned certs 2022-01-24 09:53:32 +00:00
jesopo 0a5c774965 v0.5.0 release 2022-01-20 21:28:23 +00:00
jesopo 8245a411c0 hmm no this isnt how you ask for cert validation apparently 2022-01-20 21:24:54 +00:00
jesopo 80b941fa53 handle ERR_UNAVAILRESOURCE for prereg NICK failure too 2022-01-16 15:24:22 +00:00
jesopo b7019d35c1 upgrade ircstates 2022-01-07 19:04:48 +00:00
jesopo 66358f77e3 v0.4.7 release 2022-01-07 11:45:32 +00:00
jesopo fcd2f5b1b2 upgrade ircstates 2022-01-07 11:43:30 +00:00
jesopo 3e18deef86 we don't support py3.6; support py3.9 2022-01-07 11:41:35 +00:00
jesopo 9ba5b2b90f add `transport` (ITCPTransport) param to bot.add_server 2021-12-18 16:48:01 +00:00
jesopo 025fde97ee v0.4.6 release 2021-12-09 23:53:14 +00:00
jesopo 05750f00d9 make sure 'tls' is defined 2021-12-09 23:49:40 +00:00
16 changed files with 109 additions and 58 deletions

View File

@ -3,7 +3,7 @@ cache: pip
python: python:
- "3.7" - "3.7"
- "3.8" - "3.8"
- "3.8-dev" - "3.9"
install: install:
- pip3 install mypy -r requirements.txt - pip3 install mypy -r requirements.txt
script: script:

View File

@ -1 +1 @@
0.4.5 0.6.6

View File

@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
params = ConnectionParams( params = ConnectionParams(
nickname, nickname,
hostname, hostname,
6697, 6697
tls=True) )
await bot.add_server("freenode", params) await bot.add_server("freenode", params)
await bot.run() await bot.run()

View File

@ -23,7 +23,6 @@ async def main():
"MyNickname", "MyNickname",
host = "chat.freenode.invalid", host = "chat.freenode.invalid",
port = 6697, port = 6697,
tls = True,
sasl = sasl_params) sasl = sasl_params)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)

View File

@ -25,7 +25,7 @@ class Bot(BaseBot):
async def main(): async def main():
bot = Bot() bot = Bot()
for name, host in SERVERS: for name, host in SERVERS:
params = ConnectionParams("BitBotNewTest", host, 6697, True) params = ConnectionParams("BitBotNewTest", host, 6697)
await bot.add_server(name, params) await bot.add_server(name, params)
await bot.run() await bot.run()

View File

@ -3,3 +3,4 @@ from .server import Server
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM, from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy, ResumePolicy) STSPolicy, ResumePolicy)
from .ircv3 import Capability from .ircv3 import Capability
from .security import TLS

View File

@ -6,7 +6,7 @@ from ircstates.server import ServerDisconnectedException
from .server import ConnectionParams, Server from .server import ConnectionParams, Server
from .transport import TCPTransport from .transport import TCPTransport
from .interface import IBot, IServer from .interface import IBot, IServer, ITCPTransport
class Bot(IBot): class Bot(IBot):
def __init__(self): def __init__(self):
@ -38,10 +38,13 @@ class Bot(IBot):
del self.servers[server.name] del self.servers[server.name]
await server.disconnect() await server.disconnect()
async def add_server(self, name: str, params: ConnectionParams) -> Server: async def add_server(self,
name: str,
params: ConnectionParams,
transport: ITCPTransport = TCPTransport()) -> Server:
server = self.create_server(name) server = self.create_server(name)
self.servers[name] = server self.servers[name] = server
await server.connect(TCPTransport(), params) await server.connect(transport, params)
await self._server_queue.put(server) await self._server_queue.put(server)
return server return server

View File

@ -6,6 +6,7 @@ from ircstates import Server, Emit
from irctokens import Line, Hostmask from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
@ -24,11 +25,10 @@ class ITCPWriter(object):
class ITCPTransport(object): class ITCPTransport(object):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: bool, tls: Optional[TLS],
tls_verify: bool=True, bindhost: Optional[str]=None
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
pass pass

View File

@ -8,6 +8,7 @@ from .contexts import ServerContext
from .matching import Response, ANY from .matching import Response, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLSVerifyChain
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
return d return d
async def sts_transmute(params: ConnectionParams): async def sts_transmute(params: ConnectionParams):
if not params.sts is None and not params.tls: if not params.sts is None and params.tls is None:
now = time() now = time()
since = (now-params.sts.created) since = (now-params.sts.created)
if since <= params.sts.duration: if since <= params.sts.duration:
params.port = params.sts.port params.port = params.sts.port
params.tls = True params.tls = TLSVerifyChain()
async def resume_transmute(params: ConnectionParams): async def resume_transmute(params: ConnectionParams):
if params.resume is not None: if params.resume is not None:
params.host = params.resume.address params.host = params.resume.address
@ -182,7 +183,7 @@ class CAPContext(ServerContext):
if not params.tls: if not params.tls:
if "port" in sts_dict: if "port" in sts_dict:
params.port = int(sts_dict["port"]) params.port = int(sts_dict["port"])
params.tls = True params.tls = TLSVerifyChain()
await self.server.bot.disconnect(self.server) await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params) await self.server.bot.add_server(self.server.name, params)

View File

@ -73,8 +73,7 @@ class Formatless(IMatchResponseParam):
def __init__(self, value: TYPE_MAYBELIT_VALUE): def __init__(self, value: TYPE_MAYBELIT_VALUE):
self._value = _assure_lit(value) self._value = _assure_lit(value)
def __repr__(self) -> str: def __repr__(self) -> str:
brepr = super().__repr__() return f"Formatless({self._value!r})"
return f"Formatless({brepr})"
def match(self, server: IServer, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
strip = formatting.strip(arg) strip = formatting.strip(arg)
return self._value.match(server, strip) return self._value.match(server, strip)

View File

@ -1,6 +1,9 @@
from re import compile as re_compile
from typing import List, Optional from typing import List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .security import TLS, TLSNoVerify, TLSVerifyChain
class SASLParams(object): class SASLParams(object):
mechanism: str mechanism: str
@ -28,19 +31,24 @@ class ResumePolicy(object):
address: str address: str
token: str token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = {
"+": TLSVerifyChain,
"~": TLSNoVerify,
}
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
host: str host: str
port: int port: int
tls: bool tls: Optional[TLS] = field(default_factory=TLSVerifyChain)
username: Optional[str] = None username: Optional[str] = None
realname: Optional[str] = None realname: Optional[str] = None
bindhost: Optional[str] = None bindhost: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
tls_verify: bool = True
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None sts: Optional[STSPolicy] = None
@ -57,13 +65,19 @@ class ConnectionParams(object):
hoststring: str hoststring: str
) -> "ConnectionParams": ) -> "ConnectionParams":
host, _, port_s = hoststring.strip().partition(":") ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:]
else:
host, _, port_s = hoststring.strip().partition(":")
if port_s.startswith("+"): tls_type: Optional[TLS] = None
tls = True if not port_s:
port_s = port_s.lstrip("+") or "6697"
elif not port_s:
tls = False
port_s = "6667" port_s = "6667"
else:
tls_type = _TLS_TYPES.get(port_s[0], lambda: None)()
if tls_type is not None:
port_s = port_s[1:] or "6697"
return ConnectionParams(nickname, host, int(port_s), tls) return ConnectionParams(nickname, host, int(port_s), tls_type)

View File

@ -1,13 +1,29 @@
import ssl import ssl
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class TLS:
client_keypair: Optional[Tuple[str, str]] = None
# tls without verification
class TLSNoVerify(TLS):
pass
# verify via CAs
class TLSVerifyChain(TLS):
pass
# verify by a pinned hash
class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str):
self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash):
pass
def tls_context(verify: bool=True) -> ssl.SSLContext: def tls_context(verify: bool=True) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS) ctx = ssl.create_default_context()
context.options |= ssl.OP_NO_SSLv2 if not verify:
context.options |= ssl.OP_NO_SSLv3 ctx.check_hostname = False
context.options |= ssl.OP_NO_TLSv1 ctx.verify_mode = ssl.CERT_NONE
context.load_default_certs() return ctx
if verify:
context.verify_mode = ssl.CERT_REQUIRED
return context

View File

@ -124,9 +124,8 @@ class Server(IServer):
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host,
params.port, params.port,
tls =params.tls, tls =params.tls,
tls_verify=params.tls_verify, bindhost =params.bindhost)
bindhost =params.bindhost)
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
@ -181,9 +180,9 @@ class Server(IServer):
self._pending_who[0] == chan): self._pending_who[0] == chan):
self._pending_who.popleft() self._pending_who.popleft()
await self._next_who() await self._next_who()
elif (line.command in {
elif (line.command in {ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME} and ERR_NICKNAMEINUSE, ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE
not self.registered): } and not self.registered):
if self._alt_nicks: if self._alt_nicks:
nick = self._alt_nicks.pop(0) nick = self._alt_nicks.pop(0)
await self.send(build("NICK", [nick])) await self.send(build("NICK", [nick]))
@ -288,9 +287,10 @@ class Server(IServer):
if not self._process_queue: if not self._process_queue:
async with self._read_lwork: async with self._read_lwork:
read_aw = self._read_line(PING_TIMEOUT) read_aw = asyncio.create_task(self._read_line(PING_TIMEOUT))
wait_aw = asyncio.create_task(self._wait_for.wait())
dones, notdones = await asyncio.wait( dones, notdones = await asyncio.wait(
[read_aw, self._wait_for.wait()], [read_aw, wait_aw],
return_when=asyncio.FIRST_COMPLETED return_when=asyncio.FIRST_COMPLETED
) )
self._wait_for.clear() self._wait_for.clear()

View File

@ -1,10 +1,12 @@
from hashlib import sha512
from ssl import SSLContext from ssl import SSLContext
from typing import Optional, Tuple from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import tls_context from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
TLSVerifySHA512)
class TCPReader(ITCPReader): class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader): def __init__(self, reader: StreamReader):
@ -32,16 +34,18 @@ class TCPWriter(ITCPWriter):
class TCPTransport(ITCPTransport): class TCPTransport(ITCPTransport):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: bool, tls: Optional[TLS],
tls_verify: bool=True, bindhost: Optional[str]=None
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None cur_ssl: Optional[SSLContext] = None
if tls: if tls is not None:
cur_ssl = tls_context(tls_verify) cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
if tls.client_keypair is not None:
(client_cert, client_key) = tls.client_keypair
cur_ssl.load_cert_chain(client_cert, keyfile=client_key)
local_addr: Optional[Tuple[str, int]] = None local_addr: Optional[Tuple[str, int]] = None
if not bindhost is None: if not bindhost is None:
@ -55,5 +59,20 @@ class TCPTransport(ITCPTransport):
server_hostname=server_hostname, server_hostname=server_hostname,
ssl =cur_ssl, ssl =cur_ssl,
local_addr =local_addr) local_addr =local_addr)
if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info(
"ssl_object"
).getpeercert(True)
if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest()
else:
raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum:
raise ValueError(
f"pinned hash for {hostname} does not match ({sum})"
)
return (TCPReader(reader), TCPWriter(writer)) return (TCPReader(reader), TCPWriter(writer))

View File

@ -1,7 +1,6 @@
anyio ~=2.0.2 anyio ~=2.0.2
asyncio-rlock ~=0.1.0 asyncio-rlock ~=0.1.0
asyncio-throttle ~=1.0.1 asyncio-throttle ~=1.0.1
dataclasses ~=0.6; python_version<"3.7" ircstates ~=0.12.1
ircstates ~=0.11.10
async_stagger ~=0.3.0 async_stagger ~=0.3.0
async_timeout ~=3.0.1 async_timeout ~=4.0.2

View File

@ -26,6 +26,6 @@ setup(
"Operating System :: Microsoft :: Windows", "Operating System :: Microsoft :: Windows",
"Topic :: Communications :: Chat :: Internet Relay Chat" "Topic :: Communications :: Chat :: Internet Relay Chat"
], ],
python_requires='>=3.6', python_requires='>=3.7',
install_requires=install_requires install_requires=install_requires
) )