237 lines
8.3 KiB
Python
237 lines
8.3 KiB
Python
"""
|
|
IRC protocol class, plus the associated logging config and a shared SSL context.
|
|
"""
|
|
|
|
import asyncio
|
|
import ssl
|
|
import logging
|
|
|
|
# Root logging setup: timestamped lines, INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s]%(levelname)s%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Relabel the built-in levels: INFO lines get a blank tag, WARNING lines a
# conspicuous one.
logging.addLevelName(logging.INFO, ' ')
logging.addLevelName(logging.WARNING, ' ::WARNING:: ')

# Shared SSL context used for every server connection.
# NOTE(review): both hostname checking and certificate verification are
# disabled, so traffic is encrypted but the server is not authenticated
# (open to MITM). Presumably intentional to allow self-signed IRC servers;
# confirm before relying on it.
SSL_CTX = ssl.create_default_context()
SSL_CTX.check_hostname = False  # must be off before verify_mode may be CERT_NONE
SSL_CTX.verify_mode = ssl.CERT_NONE
|
|
|
|
class IrcProtocol( asyncio.Protocol ):
    """
    IRC protocol class.

    Represents one connection to an IRC server, built on top of asyncio's
    Protocol class.

    Accepts a bot user config on creation and a server config when
    connecting (see do_connect).

    Message, connect and disconnect callbacks can be registered; message
    callbacks are called whenever the connection receives a PRIVMSG.

    Attempts to reconnect when the connection is lost, unless the stop was
    requested manually (self._stop is True).

    Keeps self.names -- a map of each visible nick to the list of channels
    it is visible in -- up to date from NAMES/JOIN/PART/QUIT/NICK/KICK.

    Provides log/warning helpers that prefix output with the server's name.
    """

    # user-mode prefix characters that may precede a nick in a NAMES (353) reply
    _NICK_PREFIXES = '+~%&@'

    def __init__( self, evloop, user_cfg ):
        """
        Create a new protocol instance.

        evloop   -- asyncio event loop used to schedule (re)connect tasks
        user_cfg -- bot user config dict containing 'nick', 'user', 'name',
                    and 'owner' fields
        """
        self._loop = evloop
        self._recnt = 0       # consecutive failed (re)connect attempts
        self._dat = b''       # buffered bytes of a not-yet-terminated line
        self._trans = None    # this protocol's transport, set in connection_made
        self._stop = False    # True if doing a manual stop: don't try to reconnect
        self.cfg = {'usr': user_cfg, 'sv': {}}   # server config set on first connect
        # registered callbacks, by event kind
        self._cb = {'message': [], 'connect': [], 'disconnect': []}
        self.names = {}       # nick -> list of channels that nick is visible in

    # ------------------------------------------------------------------
    # asyncio.Protocol interface
    # ------------------------------------------------------------------

    def connection_made( self, transport ):
        """ Register NICK/USER with the server, then run connect callbacks. """
        self.log( 'Connection made! {}'.format( transport ) )
        self._trans = transport
        # discard any partial line left over from a previous connection so it
        # is not glued onto the first line received on this one
        self._dat = b''
        self.send( 'NICK {}'.format( self.cfg['usr']['nick'] ) )
        self.send( 'USER {} 0 * :{}'.format( self.cfg['usr']['user'], self.cfg['usr']['name'] ) )
        for c in self._cb['connect']:
            c( self )

    def connection_lost( self, exc ):
        """
        Handle a lost connection: clear name tracking, run disconnect
        callbacks, then either finish a manual stop or schedule a reconnect.

        exc -- the exception that closed the connection, or None on EOF
        """
        self._trans.close()
        self.names.clear()
        if exc is None:
            exc = 'EOF'
        self.warning( 'Connection lost to server \'{}\'! ({})'.format( self.cfg['sv']['host'], exc ) )
        for c in self._cb['disconnect']:
            c( self )
        # if this was a manual stop, don't attempt to reconnect
        if self._stop:
            self.on_stop()
            return
        self.warning( 'Attempting reconnect...' )
        self._loop.create_task( self.do_connect( self.cfg['sv'] ) )

    def data_received( self, data ):
        """
        Buffer incoming bytes and handle every complete line.

        Lines end in '\\r\\n'; a bare '\\n' is accepted too. A trailing partial
        line stays buffered until the rest of it arrives.
        """
        self._dat += data
        lines = self._dat.split( b'\n' )
        self._dat = lines.pop()   # last piece is incomplete (or b'')
        for line in lines:
            # decode leniently: IRC guarantees no charset, and a single bad
            # byte must not raise out of data_received (asyncio would then
            # tear the connection down)
            raw_msg = line.rstrip( b'\r' ).decode( 'utf-8', errors='replace' )
            self._handle_line( raw_msg )

    # ------------------------------------------------------------------
    # message parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _prefix_nick( prefix ):
        """ Return the nick from a ':nick!user@host' prefix (whole prefix if no '!'). """
        if prefix.startswith( ':' ):
            prefix = prefix[1:]
        return prefix.split( '!', 1 )[0]

    @staticmethod
    def _strip_colon( param ):
        """ Drop a single leading ':' from a message parameter, if present. """
        return param[1:] if param.startswith( ':' ) else param

    def _handle_line( self, raw_msg ):
        """ Parse one decoded IRC line, updating state and firing callbacks. """
        words = raw_msg.split()
        # every message handled below has at least two tokens; skip blank or
        # malformed lines instead of raising IndexError
        if len( words ) < 2:
            return

        if words[0] == 'PING':
            self.send( 'PONG ' + words[1] )
            return

        cmd = words[1]
        # register success, join initial channels
        if cmd == '001':
            self._on_welcome()
        # on privmsg, callbacks get self, chan, nick, and message
        elif cmd == 'PRIVMSG' and len( words ) > 2:
            self._on_privmsg( raw_msg, words )
        # NAMES reply: ':server 353 <me> <type> <chan> :[prefix]nick ...'
        elif cmd == '353' and len( words ) > 4:
            self._on_names( words )
        elif cmd == 'JOIN' and len( words ) > 2:
            nick = self._prefix_nick( words[0] )
            chan = self._strip_colon( words[2] )
            if nick == self.cfg['usr']['nick']:
                self.log( '::JOIN:: {} has joined {}.'.format( nick, chan ) )
            self.join_name( nick, chan )
        elif cmd == 'PART' and len( words ) > 2:
            nick = self._prefix_nick( words[0] )
            chan = self._strip_colon( words[2] )
            if nick == self.cfg['usr']['nick']:
                self.log( '::PART:: {} has left {}.'.format( nick, chan ) )
            self.part_name( nick, chan )
        elif cmd == 'QUIT':
            nick = self._prefix_nick( words[0] )
            if nick == self.cfg['usr']['nick']:
                self.log( '::QUIT:: {} has quit the server.'.format( nick ) )
            self.quit_name( nick )
        elif cmd == 'NICK' and len( words ) > 2:
            self.log( raw_msg )
            self.change_name( self._prefix_nick( words[0] ), self._strip_colon( words[2] ) )
        # ':op!u@h KICK <chan> <victim> :reason'
        elif cmd == 'KICK' and len( words ) > 3:
            chan, victim = words[2], words[3]
            if victim == self.cfg['usr']['nick']:
                self.log( '::KICK:: {} has been kicked from {}.'.format( victim, chan ) )
            self.part_name( victim, chan )

    def _on_welcome( self ):
        """ RPL_WELCOME (001): reset the retry counter and join configured channels. """
        self._recnt = 0
        if 'channels' in self.cfg['sv']:
            for c in self.cfg['sv']['channels']:
                self.send( 'JOIN {}'.format( c ) )
        else:
            self.cfg['sv']['channels'] = []

    def _on_privmsg( self, raw_msg, words ):
        """ Extract destination, sender nick and text; pass them to message callbacks. """
        dst = words[2]
        nick = self._prefix_nick( words[0] )
        # the text is the trailing parameter introduced by ' :'. Search for
        # that rather than the first ':' because the sender's hostmask (e.g.
        # an IPv6 address) may itself contain ':'. The prefix, command and
        # target contain no spaces, so the first ' :' is the trailing marker.
        sep = raw_msg.find( ' :' )
        if sep != -1:
            msg = raw_msg[sep + 2:]
        else:
            # a single-word trailing parameter may legally omit the colon
            msg = ' '.join( words[3:] )
        for cb in self._cb['message']:
            cb( self, dst, nick, msg )

    def _on_names( self, words ):
        """ Record every nick listed in a NAMES (353) reply as visible in its channel. """
        names_chan = words[4]
        for n in words[5:]:
            n = self._strip_colon( n )            # first name carries the ':'
            # strip all user-mode prefixes (multi-prefix sends e.g. '@+nick');
            # nicks themselves cannot start with these characters
            n = n.lstrip( self._NICK_PREFIXES )
            if n:
                self.join_name( n, names_chan )

    # ------------------------------------------------------------------
    # names management
    # ------------------------------------------------------------------

    def join_name( self, name, chan ):
        """ Record that `name` is visible in `chan`. """
        chans = self.names.setdefault( name, [] )
        if chan not in chans:
            chans.append( chan )

    def part_name( self, name, chan ):
        """
        Remove `chan` from `name`'s channels, dropping `name` entirely once it
        has no channels left. Untracked names are ignored.

        When `name` is the bot itself, every other name also loses `chan`.
        """
        chans = self.names.get( name )
        if chans is not None:
            if chan in chans:
                chans.remove( chan )
            if not chans:
                self.quit_name( name )
        if name == self.cfg['usr']['nick']:
            # collect first: part_name mutates self.names while we iterate
            for n in [k for k, v in self.names.items() if chan in v]:
                self.part_name( n, chan )

    def quit_name( self, name ):
        """ Forget `name` entirely (no-op if untracked). """
        self.names.pop( name, None )

    def change_name( self, name, new ):
        """
        Move `name`'s channel list over to `new` after a nick change.

        An untracked `name` is ignored, so it neither creates an empty entry
        nor overwrites an existing one for `new`.
        """
        if name in self.names:
            self.names[new] = self.names.pop( name )

    # ------------------------------------------------------------------
    # callbacks
    # ------------------------------------------------------------------

    def _add_callback( self, kind, cb ):
        """ Register `cb` for event `kind` unless it is already registered. """
        if cb not in self._cb[kind]:
            self._cb[kind].append( cb )

    def add_message_callback( self, cb ):
        """
        Adds a function to call back when a PRIVMSG is received.
        It is passed self, dst, nick, and the message.
        """
        self._add_callback( 'message', cb )

    def add_connect_callback( self, cb ):
        """ Adds a function called with self when a connection is made. """
        self._add_callback( 'connect', cb )

    def add_disconnect_callback( self, cb ):
        """ Adds a function called with self when the connection is lost. """
        self._add_callback( 'disconnect', cb )

    # ------------------------------------------------------------------
    # logging
    # ------------------------------------------------------------------

    def log( self, msg, level=logging.INFO ):
        """ Log `msg` at `level`, prefixed with this connection's configured name. """
        logging.log( level, '(%s) %s', self.cfg['sv'].get( 'name', '?' ), msg )

    def warning( self, msg ):
        """ Log `msg` at WARNING level. """
        self.log( msg, logging.WARNING )

    # ------------------------------------------------------------------
    # sending
    # ------------------------------------------------------------------

    def send( self, msg ):
        """
        Write one raw IRC line to the server.

        IRC allows 512 bytes per line including the trailing '\\r\\n', so the
        UTF-8 encoding is cut to 510 bytes. A multi-byte character split by
        the cut is dropped rather than sent half-encoded. Without an open
        transport the line is dropped with a warning.
        """
        if self._trans is None or self._trans.is_closing():
            self.warning( 'Not connected, dropping: {}'.format( msg ) )
            return
        data = msg.encode( 'utf-8' )[:510].decode( 'utf-8', errors='ignore' ).encode( 'utf-8' )
        self._trans.write( data + b'\r\n' )

    def say_to( self, dst, msg ):
        """ Send a PRIVMSG to `dst`, with line breaks flattened so it stays one line. """
        msg = msg.replace( '\n', ' ' ).replace( '\r', '' )
        self.log( '{} >> {}'.format( dst, msg ) )
        self.send( 'PRIVMSG {} :{}'.format( dst, msg ) )

    # ------------------------------------------------------------------
    # connection management
    # ------------------------------------------------------------------

    async def do_connect( self, sv_cfg ):
        """
        Attempts a connection using the specified server config dict.

        Fields: 'name', 'host', 'port', and an optional 'channels' list to
        join on connect. 'name' is a custom label for this connection, used
        mainly to make logging output easier to read. Only the first config
        passed is stored; later calls reuse it.

        On failure, retries after a linearly growing delay (10s per failed
        attempt). After 60 consecutive failures it gives up and calls on_stop.

        Called using an event loop's create_task function.
        """
        # set initial server config
        if not self.cfg['sv']:
            self.cfg['sv'] = sv_cfg

        self.log( 'Connecting to server \'{}\'...'.format( self.cfg['sv']['host'] ) )
        try:
            # no loop= argument: it was deprecated in 3.8 and removed in 3.10,
            # where passing it raises TypeError on every attempt
            await asyncio.wait_for( self._loop.create_connection(
                lambda: self, self.cfg['sv']['host'], self.cfg['sv']['port'], ssl=SSL_CTX ),
                10 )
        except asyncio.CancelledError:
            # task cancellation (e.g. shutdown) must not be treated as a
            # failed attempt and retried; CancelledError is an Exception on 3.7
            raise
        except Exception as e:
            # str() of a timeout is empty, so fall back to the exception's type
            self.warning( 'Connection exception! {}'.format( str( e ) or type( e ).__name__ ) )
            self._recnt += 1
            # if enough reconnect attempts were made, stop trying
            if self._recnt > 60:
                self.warning( 'Reconnect attempt limit reached, stopping connection...' )
                self.on_stop()
                return
            delay = self._recnt * 10
            self.warning( 'Reconnect attempt #{} in {}s...'.format( self._recnt, delay ) )
            await asyncio.sleep( delay )
            self._loop.create_task( self.do_connect( self.cfg['sv'] ) )

    def on_stop( self ):
        """ Called once a connection has been stopped for good. """
        self.warning( 'Stopped connection to server \'{}\'.'.format( self.cfg['sv']['host'] ) )
|
|
|