mirror of https://github.com/str4d/RelayBot.git
Add SSL fingerprint verification.
This commit is contained in:
parent
2838689951
commit
e3847f76cc
115
relaybot.py
115
relaybot.py
|
@ -1,13 +1,17 @@
|
|||
from twisted.words.protocols import irc
|
||||
from twisted.internet import reactor, protocol, ssl
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
from twisted.python import log
|
||||
from twisted.python import log, reflect, util
|
||||
from twisted.internet.endpoints import clientFromString
|
||||
from twisted.internet.error import VerifyError, CertificateError
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.application import service
|
||||
from twisted.python.hashlib import md5
|
||||
from OpenSSL import SSL, crypto
|
||||
from signal import signal, SIGINT
|
||||
from ConfigParser import SafeConfigParser
|
||||
import re, sys
|
||||
import re, sys, itertools
|
||||
|
||||
#
|
||||
# RelayBot is a derivative of http://code.google.com/p/relaybot/
|
||||
|
@ -18,6 +22,8 @@ log.startLogging(sys.stdout)
|
|||
__version__ = "0.1"
|
||||
application = service.Application("RelayBot")
|
||||
|
||||
_sessionCounter = itertools.count().next
|
||||
|
||||
def main():
|
||||
config = SafeConfigParser()
|
||||
config.read("relaybot.config")
|
||||
|
@ -32,7 +38,7 @@ def main():
|
|||
return None
|
||||
|
||||
options = {}
|
||||
for option in [ "timeout", "host", "port", "nick", "channel", "heartbeat", "password", "username", "realname", "mode", "ssl" ]:
|
||||
for option in [ "timeout", "host", "port", "nick", "channel", "heartbeat", "password", "username", "realname", "mode", "ssl", "fingerprint" ]:
|
||||
options[option] = get(option)
|
||||
|
||||
mode = get("mode")
|
||||
|
@ -56,12 +62,113 @@ def main():
|
|||
|
||||
factory = factory(options)
|
||||
if options['ssl'] == "True":
|
||||
reactor.connectSSL(options['host'], int(options['port']), factory, ssl.ClientContextFactory(), int(options['timeout']))
|
||||
if options['fingerprint']:
|
||||
ctx = certoptions(fingerprint=options['fingerprint'], verifyDepth=0)
|
||||
reactor.connectSSL(options['host'], int(options['port']), factory, ctx, int(options['timeout']))
|
||||
else:
|
||||
reactor.connectSSL(options['host'], int(options['port']), factory, ssl.ClientContextFactory(), int(options['timeout']))
|
||||
else:
|
||||
reactor.connectTCP(options['host'], int(options['port']), factory, int(options['timeout']))
|
||||
|
||||
reactor.callWhenRunning(signal, SIGINT, handler)
|
||||
|
||||
class certoptions(object):
|
||||
_context = None
|
||||
_OP_ALL = getattr(SSL, 'OP_ALL', 0x0000FFFF)
|
||||
_OP_NO_TICKET = 0x00004000
|
||||
method = SSL.TLSv1_METHOD
|
||||
|
||||
def __init__(self, privateKey=None, certificate=None, method=None, verify=False, caCerts=None, verifyDepth=9, requireCertificate=True, verifyOnce=True, enableSingleUseKeys=True, enableSessions=True, fixBrokenPeers=False, enableSessionTickets=False, fingerprint=True):
|
||||
assert (privateKey is None) == (certificate is None), "Specify neither or both of privateKey and certificate"
|
||||
self.privateKey = privateKey
|
||||
self.certificate = certificate
|
||||
if method is not None:
|
||||
self.method = method
|
||||
|
||||
self.verify = verify
|
||||
assert ((verify and caCerts) or
|
||||
(not verify)), "Specify client CA certificate information if and only if enabling certificate verification"
|
||||
|
||||
self.caCerts = caCerts
|
||||
self.verifyDepth = verifyDepth
|
||||
self.requireCertificate = requireCertificate
|
||||
self.verifyOnce = verifyOnce
|
||||
self.enableSingleUseKeys = enableSingleUseKeys
|
||||
self.enableSessions = enableSessions
|
||||
self.fixBrokenPeers = fixBrokenPeers
|
||||
self.enableSessionTickets = enableSessionTickets
|
||||
self.fingerprint = fingerprint
|
||||
|
||||
def __getstate__(self):
|
||||
d = self.__dict__.copy()
|
||||
try:
|
||||
del d['_context']
|
||||
except KeyError:
|
||||
pass
|
||||
return d
|
||||
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
|
||||
|
||||
def getContext(self):
|
||||
if self._context is None:
|
||||
self._context = self._makeContext()
|
||||
return self._context
|
||||
|
||||
|
||||
def _makeContext(self):
|
||||
ctx = SSL.Context(self.method)
|
||||
|
||||
if self.certificate is not None and self.privateKey is not None:
|
||||
ctx.use_certificate(self.certificate)
|
||||
ctx.use_privatekey(self.privateKey)
|
||||
ctx.check_privatekey()
|
||||
|
||||
verifyFlags = SSL.VERIFY_NONE
|
||||
if self.verify or self.fingerprint:
|
||||
verifyFlags = SSL.VERIFY_PEER
|
||||
if self.requireCertificate:
|
||||
verifyFlags |= SSL.VERIFY_FAIL_IF_NO_PEER_CERT
|
||||
if self.verifyOnce:
|
||||
verifyFlags |= SSL.VERIFY_CLIENT_ONCE
|
||||
if self.caCerts:
|
||||
store = ctx.get_cert_store()
|
||||
for cert in self.caCerts:
|
||||
store.add_cert(cert)
|
||||
|
||||
def _verifyCallback(conn, cert, errno, depth, preverify_ok):
|
||||
if self.fingerprint:
|
||||
digest = cert.digest("sha1")
|
||||
if digest != self.fingerprint:
|
||||
log.msg("Remote server fingerprint mismatch. Got: %s Expect: %s" % (digest, self.fingerprint))
|
||||
return False
|
||||
else:
|
||||
log.msg("Remote server fingerprint match: %s " % (digest))
|
||||
return True
|
||||
return preverify_ok
|
||||
|
||||
ctx.set_verify(verifyFlags, _verifyCallback)
|
||||
|
||||
if self.verifyDepth is not None:
|
||||
ctx.set_verify_depth(self.verifyDepth)
|
||||
|
||||
if self.enableSingleUseKeys:
|
||||
ctx.set_options(SSL.OP_SINGLE_DH_USE)
|
||||
|
||||
if self.fixBrokenPeers:
|
||||
ctx.set_options(self._OP_ALL)
|
||||
|
||||
if self.enableSessions:
|
||||
sessionName = md5("%s-%d" % (reflect.qual(self.__class__), _sessionCounter())).hexdigest()
|
||||
ctx.set_session_id(sessionName)
|
||||
|
||||
if not self.enableSessionTickets:
|
||||
ctx.set_options(self._OP_NO_TICKET)
|
||||
|
||||
return ctx
|
||||
|
||||
class Communicator:
|
||||
def __init__(self):
|
||||
self.protocolInstances = {}
|
||||
|
|
Loading…
Reference in New Issue