dustbot/protocol.py

254 lines
8.9 KiB
Python

"""
IRC protocol class and associated logging config and a shared ssl context
"""
import asyncio
import ssl
import logging
# Root logging setup for every connection. The format string has no
# separators between its fields, so the level names below carry their own
# padding: INFO records show just a space, WARNING records are flagged loudly.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s]%(levelname)s%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logging.addLevelName( logging.INFO, ' ' )
logging.addLevelName( logging.WARNING, ' ::WARNING:: ' )

# Shared SSL context used by every connection.
# NOTE: hostname checking and certificate verification are both switched off,
# so traffic is encrypted but the server is NOT authenticated.
# check_hostname has to be cleared before verify_mode can be set to CERT_NONE.
SSL_CTX = ssl.create_default_context()
SSL_CTX.check_hostname = False
SSL_CTX.verify_mode = ssl.CERT_NONE
class IrcProtocol( asyncio.Protocol ):
    """
    IRC protocol class.

    Represents a connection to one IRC server, built on top of asyncio's
    Protocol class. Accepts a user config on creation and a server config when
    connecting (see do_connect). Message, connect and disconnect callbacks can
    be registered and are called as the connection receives those events.
    Tracks which nicks are visible in which channels (self.names), attempts to
    reconnect when the connection is lost unless it was stopped deliberately,
    and labels its log output with the server config's 'name'.
    """
    # minimum number of whitespace-separated words a line needs for each handled
    # command (default 2); shorter, malformed lines are ignored instead of raising
    _MIN_WORDS = {'PRIVMSG': 3, '353': 5, 'JOIN': 3, 'PART': 3, 'NICK': 3, 'KICK': 4}

    def __init__( self, evloop, user_cfg ):
        """ Creates a new protocol instance.

        evloop   -- asyncio event loop used to schedule (re)connection tasks
        user_cfg -- bot user config dict containing 'nick', 'user', 'name',
                    and 'owner' fields
        """
        self._loop = evloop
        self._recnt = 0     # failed connection attempts since the last successful registration
        self._dat = b''     # buffered bytes of a not-yet-complete line
        self._trans = None  # protocol's transport
        self._stop = False  # True if doing a manual stop and don't try to reconnect
        self.cfg = {'usr': user_cfg, 'sv': {}}  # server config set on first connect
        # callbacks
        self._cb = {'message': [], 'connect': [], 'disconnect': []}
        self.names = {}  # map of names on server to list of channels they're visible in

    # --- asyncio.Protocol interface ---

    def connection_made( self, transport ):
        """ Registers with the server (NICK/USER) and runs the connect callbacks. """
        self.log( 'Connection made! {}'.format( transport ) )
        self._trans = transport
        # a new connection is a new byte stream: drop any partial line left over
        # from a previous connection so it can't corrupt the first line received
        self._dat = b''
        usr = self.cfg['usr']
        self.send( 'NICK {}'.format( usr['nick'] ) )
        self.send( 'USER {} 0 * :{}'.format( usr['user'], usr['name'] ) )
        for cb in self._cb['connect']:
            cb( self )

    def connection_lost( self, exc ):
        """ Cleans up after a lost connection, runs the disconnect callbacks and
        schedules a reconnect unless the connection was stopped deliberately.

        exc -- exception that closed the connection, or None on EOF
        """
        self._trans.close()
        self.names.clear()  # channel membership is rebuilt from JOIN/NAMES after reconnecting
        if exc is None:
            exc = 'EOF'
        self.warning( 'Connection lost to server \'{}\'! ({})'.format( self.cfg['sv']['host'], exc ) )
        for cb in self._cb['disconnect']:
            cb( self )
        # if was a manual stop, don't attempt reconnect
        if self._stop:
            self.on_stop()
            return
        self.warning( 'Attempting reconnect...' )
        self._loop.create_task( self.do_connect( self.cfg['sv'] ) )

    def data_received( self, data ):
        """ Buffers incoming bytes and handles every complete line received so far.

        Lines end in LF (normally CRLF); a trailing partial line stays buffered
        until the rest of it arrives.
        """
        lines = ( self._dat + data ).split( b'\n' )
        self._dat = lines.pop()  # incomplete tail, b'' if data ended on a newline
        for line in lines:
            # IRC doesn't enforce an encoding: decode leniently so one bad byte
            # can't raise out of data_received (asyncio would drop the connection)
            self._handle_line( line.rstrip( b'\r' ).decode( 'utf-8', errors='replace' ) )

    # --- message parsing ---

    @staticmethod
    def _strip_colon( word ):
        """ Returns word without its leading ':' (IRC's trailing-parameter marker), if any. """
        return word[1:] if word.startswith( ':' ) else word

    @staticmethod
    def _nick_of( prefix ):
        """ Returns the nick of a ':nick!user@host' prefix; a prefix without '!'
        (e.g. a server name) is returned whole. """
        return ( prefix[1:] if prefix.startswith( ':' ) else prefix ).partition( '!' )[0]

    def _handle_line( self, raw_msg ):
        """ Parses and dispatches one decoded IRC line (without its line terminator). """
        words = raw_msg.split()
        if not words:
            return  # blank line
        if words[0] == 'ERROR':
            self.log( raw_msg )
            # closing error (too many users?!): don't try to reconnect
            if len( words ) > 1 and 'Closing' in words[1]:
                self._stop = True
            return
        if words[0] == 'PING':
            self.send( ( 'PONG ' + words[1] ) if len( words ) > 1 else 'PONG' )
            return
        cmd = words[1] if len( words ) > 1 else ''
        if len( words ) < self._MIN_WORDS.get( cmd, 2 ):
            self.warning( 'Ignoring malformed message: {}'.format( raw_msg ) )
            return
        me = self.cfg['usr']['nick']
        if cmd == '001':
            # register success: reset reconnect count, join initial channels
            self._recnt = 0
            for chan in self.cfg['sv'].setdefault( 'channels', [] ):
                self.send( 'JOIN {}'.format( chan ) )
        elif cmd == 'PRIVMSG':
            # callbacks get self, destination, sender's nick and message text
            dst = words[2]
            nick = self._nick_of( words[0] )
            # the text follows the first ':' after the prefix; that colon may be
            # omitted when the text is a single word
            colon = raw_msg.find( ':', len( words[0] ) )
            msg = raw_msg[colon + 1:] if colon != -1 else ' '.join( words[3:] )
            for cb in self._cb['message']:
                cb( self, dst, nick, msg )
        elif cmd == '353':
            # NAMES reply: ':server 353 <me> <type> <chan> :<name> <name> ...'
            chan = words[4]
            for entry in words[5:]:
                # drop the ':' on the first name and any user-mode prefixes
                # (several may be present, e.g. '@+nick' with multi-prefix)
                name = self._strip_colon( entry ).lstrip( '~&@%+' )
                if name:
                    self.join_name( name, chan )
        elif cmd == 'JOIN':
            self.log( '::JOIN:: {}'.format( raw_msg ) )
            nick = self._nick_of( words[0] )
            chan = self._strip_colon( words[2] )
            if nick == me:
                self.log( '::JOIN:: {} has joined {}.'.format( nick, chan ) )
            self.join_name( nick, chan )
        elif cmd == 'PART':
            self.log( '::PART:: {}'.format( raw_msg ) )
            nick = self._nick_of( words[0] )
            chan = self._strip_colon( words[2] )
            if nick == me:
                self.log( '::PART:: {} has left {}.'.format( nick, chan ) )
            self.part_name( nick, chan )
        elif cmd == 'QUIT':
            self.log( '::QUIT:: {}'.format( raw_msg ) )
            nick = self._nick_of( words[0] )
            if nick == me:
                self.log( '::QUIT:: {} has quit the server.'.format( nick ) )
            self.quit_name( nick )
        elif cmd == 'NICK':
            self.log( raw_msg )
            # NOTE(review): if the bot's own nick changes, cfg['usr']['nick'] is
            # not updated, so later self-detection uses the old nick — confirm intended
            self.change_name( self._nick_of( words[0] ), self._strip_colon( words[2] ) )
        elif cmd == 'KICK':
            chan = self._strip_colon( words[2] )
            kicked = self._strip_colon( words[3] )
            if kicked == me:
                self.log( '::KICK:: {} has been kicked from {}.'.format( kicked, chan ) )
            self.part_name( kicked, chan )

    # --- names management ---

    def join_name( self, name, chan ):
        """ Records that name is visible in chan. """
        chans = self.names.setdefault( name, [] )
        if chan not in chans:
            chans.append( chan )

    def part_name( self, name, chan ):
        """ Removes chan from name's channels, forgetting name once it has none left.
        When the bot itself leaves chan, chan is removed from every tracked name. """
        chans = self.names.get( name )
        # names that aren't tracked are skipped instead of raising KeyError
        # out of data_received (which would drop the connection)
        if chans is not None:
            if chan in chans:
                chans.remove( chan )
            if not chans:
                self.quit_name( name )
        if name == self.cfg['usr']['nick']:
            # snapshot first: part_name mutates self.names
            for other in [n for n, c in self.names.items() if chan in c]:
                self.part_name( other, chan )

    def quit_name( self, name ):
        """ Forgets name entirely (no-op if it isn't tracked). """
        self.names.pop( name, None )

    def change_name( self, name, new ):
        """ Moves name's channel list to new (an untracked name yields an empty list). """
        self.names[new] = self.names.pop( name, [] )
        self.log( 'NICK {} changed to {} w/now: {}'.format( name, new, self.names[new] ) )

    # --- callbacks ---

    def _add_callback( self, kind, cb ):
        """ Registers cb once under kind ('message', 'connect' or 'disconnect'). """
        if cb not in self._cb[kind]:
            self._cb[kind].append( cb )

    def add_message_callback( self, cb ):
        """
        Adds a function to call back when a priv message is received
        (and, by default, when one is sent with say_to).
        Passes in self, dst, nick, and the message.
        """
        self._add_callback( 'message', cb )

    def add_connect_callback( self, cb ):
        """ Adds a function called with self when the connection is made. """
        self._add_callback( 'connect', cb )

    def add_disconnect_callback( self, cb ):
        """ Adds a function called with self when the connection is lost. """
        self._add_callback( 'disconnect', cb )

    # --- logging ---

    def log( self, msg, level=logging.INFO ):
        """ Logs msg at level, labelled with this connection's server name. """
        # .get: the server config stays empty until do_connect has been called
        logging.log( level, '(%s) %s', self.cfg['sv'].get( 'name', '?' ), msg )

    def warning( self, msg ):
        """ Logs msg at WARNING level. """
        self.log( msg, logging.WARNING )

    # --- sending ---

    def send( self, msg ):
        """ Writes msg to the server as one CRLF-terminated line.

        IRC lines are limited to 512 bytes including the CRLF, so the encoded
        message is cut to 510 bytes; a multi-byte character split by the cut is
        dropped so the line stays valid UTF-8.
        """
        data = msg.encode( 'utf-8' )[:510].decode( 'utf-8', errors='ignore' )
        self._trans.write( data.encode( 'utf-8' ) + b'\r\n' )

    def say_to( self, dst, msg, callback=True ):
        """ Sends a PRIVMSG to dst with newlines flattened, and by default also
        passes it to the message callbacks as if sent by the bot's nick. """
        msg = msg.replace( '\n', ' ' ).replace( '\r', '' )
        self.log( '{} >> {}'.format( dst, msg ) )
        self.send( 'PRIVMSG {} :{}'.format( dst, msg ) )
        if callback:
            for cb in self._cb['message']:
                cb( self, dst, self.cfg['usr']['nick'], msg )

    # --- connection management ---

    async def do_connect( self, sv_cfg ):
        """
        Attempts a connection using the specified server config dict.
        Fields: 'name', 'host', 'port', and optional 'channels' array to join on connect.
        The 'name' field is a custom name for this connection, mainly used to
        label logging output for easier readability.
        Only the first config passed is stored; later calls reuse it.
        On failure, retries after a delay that grows by 10s per failed attempt;
        gives up once the failure count exceeds 60 (the count resets when the
        server accepts registration).
        Called using an event loop's create_task function.
        """
        # set initial server config
        if not self.cfg['sv']:
            self.cfg['sv'] = sv_cfg
        sv = self.cfg['sv']
        self.log( 'Connecting to server \'{}\'...'.format( sv['host'] ) )
        try:
            # no loop= argument: wait_for dropped it in Python 3.10 (it raised
            # TypeError there, which made every attempt "fail")
            await asyncio.wait_for( self._loop.create_connection(
                lambda: self, sv['host'], sv['port'], ssl=SSL_CTX ), 10 )
        except Exception as e:
            # some exceptions (e.g. timeouts) have an empty message
            self.warning( 'Connection exception! {}'.format( str( e ) or type( e ).__name__ ) )
            self._recnt += 1
            # if enough reconnect attempts were made, stop trying
            if self._recnt > 60:
                self.warning( 'Reconnect attempt limit reached, stopping connection...' )
                self.on_stop()
                return
            delay = self._recnt * 10
            self.warning( 'Reconnect attempt #{} in {}s...'.format( self._recnt, delay ) )
            await asyncio.sleep( delay )
            self._loop.create_task( self.do_connect( self.cfg['sv'] ) )

    def on_stop( self ):
        """ Called once the connection is stopped for good (manual/server stop or
        reconnect limit reached); logs the stop. """
        self.warning( 'Stopped connection to server \'{}\'.'.format( self.cfg['sv']['host'] ) )