Refactored bot state, added time-based events, remodeled socket engine to be non-blocking
This commit is contained in:
parent
cf2db77eb2
commit
1952376b51
|
@ -3,6 +3,7 @@
|
|||
NOTES
|
||||
test.py
|
||||
settings.json
|
||||
*private*.json
|
||||
plugins/*private*.py
|
||||
*.vim
|
||||
*.swp
|
||||
|
|
274
bot.py
274
bot.py
|
@ -7,6 +7,12 @@ from json import dumps as json_dumps, loads as json_loads
|
|||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
from sys import stderr
|
||||
from argparse import ArgumentParser, FileType
|
||||
from time import time, sleep, mktime
|
||||
from threading import Thread
|
||||
from queue import Queue, Empty
|
||||
from select import select
|
||||
from datetime import datetime
|
||||
|
||||
def isnumeric(test):
|
||||
test.replace(".", "", 1).isdigit()
|
||||
|
@ -15,50 +21,99 @@ def eprint(*args, **kwargs):
|
|||
print(*args, file=stderr, **kwargs)
|
||||
|
||||
class Bot:
|
||||
def __init__(self, host, port, plugins, secure=True, timeout=0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.secure = secure
|
||||
self.timeout= timeout
|
||||
self.plugins = self._load_plugins(plugins)
|
||||
def __init__(self, config):
|
||||
self.server = "localhost"
|
||||
self.port = 6667
|
||||
self.secure = False
|
||||
self.timeout = 0
|
||||
self.plugins_dir = "plugins"
|
||||
|
||||
self.sock = None
|
||||
self.read_sock = None
|
||||
self.write_sock = None
|
||||
self.scripts = dict()
|
||||
self.plugins = dict()
|
||||
self.settings = dict()
|
||||
|
||||
self.mode = "reading"
|
||||
self.inputs = list()
|
||||
self.outputs = list()
|
||||
|
||||
self._load_settings(config)
|
||||
self._load_plugins()
|
||||
|
||||
self.events = Queue()
|
||||
self.actions = Queue()
|
||||
self.state = dict()
|
||||
self.state["_flags"] = dict()
|
||||
self.state["settings"] = self.settings
|
||||
|
||||
def _connect(self):
|
||||
print(f"Connecting to {self.host}:{self.port}...")
|
||||
print(f"Connecting to {self.server}:{self.port}...")
|
||||
while True:
|
||||
self.sock = socket(AF_INET, SOCK_STREAM)
|
||||
if self.secure:
|
||||
self.sock = wrap_socket(self.sock)
|
||||
|
||||
self.sock.setblocking(False)
|
||||
if self.timeout > 0:
|
||||
self.sock.settimeout(self.timeout)
|
||||
|
||||
try:
|
||||
self.sock.connect((self.host, self.port))
|
||||
self.sock.connect((self.server, self.port))
|
||||
print("Connected!")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print("Trying to connect again...")
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
def _load_plugins(self, location):
|
||||
# Needs to be run twice since connect consumes timeout
|
||||
if self.timeout > 0:
|
||||
self.sock.settimeout(self.timeout)
|
||||
|
||||
def _load_settings(self, config):
|
||||
if config is None:
|
||||
settings_path = Path(__file__).resolve().parent / "settings.json"
|
||||
settings_text = settings_path.read_text()
|
||||
|
||||
else:
|
||||
settings_text = config.read()
|
||||
|
||||
self.settings = json_loads(settings_text)
|
||||
meta_settings = self.settings.get("_meta", dict())
|
||||
self.server = meta_settings.get("server", self.server)
|
||||
self.port = meta_settings.get("port", self.port)
|
||||
self.secure = meta_settings.get("secure", self.secure)
|
||||
self.timeout = meta_settings.get("timeout", self.timeout)
|
||||
self.plugins_dir = meta_settings.get("plugins_dir", self.plugins_dir)
|
||||
|
||||
def _load_plugins(self):
|
||||
scope = dict()
|
||||
plugins = dict()
|
||||
modules = Path(__file__).parent
|
||||
for loc in location.split("/"):
|
||||
container = self.plugins_dir.replace("/", ".")
|
||||
for loc in self.plugins_dir.split("/"):
|
||||
modules = modules / loc
|
||||
|
||||
ignore_plugins = self.settings.get("ignore_plugins", list())
|
||||
extra_plugins = self.settings.get("extra_plugins", list())
|
||||
for module in modules.glob("*.py"):
|
||||
if not module.is_file() or module.name == "__init__.py":
|
||||
continue
|
||||
|
||||
package = module.name.replace(".py", "")
|
||||
plugins[package] = list()
|
||||
if package in ignore_plugins:
|
||||
continue
|
||||
|
||||
if "private" in package and package not in extra_plugins:
|
||||
continue
|
||||
|
||||
self.plugins[package] = list()
|
||||
|
||||
# Equivalent of doing "import <package>.<module>"
|
||||
container = location.replace("/", ".")
|
||||
script = import_module(f"{container}.{package}")
|
||||
self.scripts[package] = script
|
||||
|
||||
for variable in dir(script):
|
||||
# Ignore Python internals
|
||||
|
@ -73,9 +128,7 @@ class Bot:
|
|||
if not callable(script_var):
|
||||
continue
|
||||
|
||||
plugins[package].append(script_var)
|
||||
|
||||
return plugins
|
||||
self.plugins[package].append(script_var)
|
||||
|
||||
def _cleanup(self, parameters, dirt=":"):
|
||||
params = parameters.partition(dirt)
|
||||
|
@ -158,12 +211,25 @@ class Bot:
|
|||
def privmsg(self, message):
|
||||
return message.get("params", " ").split(" ", 1)
|
||||
|
||||
def send(self, message):
|
||||
def write(self, message):
|
||||
# DEBUG
|
||||
print(f">> {message}")
|
||||
try:
|
||||
self.sock.send(f"{message}\r\n".encode())
|
||||
self.write_sock.send(f"{message}\r\n".encode())
|
||||
return True
|
||||
|
||||
except sock_timeout as e:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
eprint(f"ERROR:\n{format_exc()}")
|
||||
return False
|
||||
|
||||
# Sanity fail-safe
|
||||
return False
|
||||
|
||||
def send(self, message):
|
||||
self.actions.put(message)
|
||||
|
||||
def sendto(self, channel, message):
|
||||
self.send(f"PRIVMSG {channel} :{message}")
|
||||
|
@ -174,21 +240,140 @@ class Bot:
|
|||
def leave(self, channel, farewell="Bye-bye!"):
|
||||
self.send(f"PART {channel} :{farewell}")
|
||||
|
||||
def subscribe(self, wait, callback, context=dict()):
|
||||
input_time, time_unit = int(wait[:-1]), wait[-1]
|
||||
wait_time = 0
|
||||
if time_unit == "s":
|
||||
wait_time = input_time
|
||||
|
||||
if time_unit == "m":
|
||||
wait_time = input_time * 60
|
||||
|
||||
if time_unit == "h":
|
||||
wait_time = input_time * 60 * 60
|
||||
|
||||
if time_unit == "d":
|
||||
wait_time = input_time * 60 * 60 * 24
|
||||
|
||||
if time_unit == "w":
|
||||
wait_time = input_time * 60 * 60 * 24 * 7
|
||||
|
||||
at = time() + wait_time
|
||||
event = at, callback, context
|
||||
self.events.put(event)
|
||||
|
||||
def process_events(self):
|
||||
events = list()
|
||||
#print("Processing events...")
|
||||
while not self.events.empty():
|
||||
event = self.events.get()
|
||||
if not isinstance(event, tuple):
|
||||
continue
|
||||
|
||||
if len(event) != 3 or None in event:
|
||||
continue
|
||||
|
||||
at, callback, context = event
|
||||
if time() < at:
|
||||
print(f"Skipping: {event}")
|
||||
events.append(event)
|
||||
continue
|
||||
|
||||
print(f"Running: {event}")
|
||||
callback(self, **context)
|
||||
|
||||
for event in events:
|
||||
self.events.put(event)
|
||||
|
||||
def toggle_mode(self):
|
||||
curr_index = 0 if self.mode == "reading" else 1
|
||||
next_index = 1 if curr_index == 0 else 0
|
||||
|
||||
modes = ["reading", "writing"]
|
||||
socks = [self.read_sock, self.write_sock]
|
||||
sock_lists = [self.inputs, self.outputs]
|
||||
|
||||
curr_sock = socks[curr_index]
|
||||
next_sock = socks[next_index]
|
||||
|
||||
curr_sock_list = sock_lists[curr_index]
|
||||
next_sock_list = sock_lists[next_index]
|
||||
|
||||
self.mode = modes[next_index]
|
||||
if curr_sock in curr_sock_list:
|
||||
curr_sock_list.remove(curr_sock)
|
||||
|
||||
if curr_sock not in next_sock_list:
|
||||
next_sock_list.append(curr_sock)
|
||||
|
||||
def set_read_mode(self):
|
||||
self.mode == "reading"
|
||||
if self.write_sock in self.outputs:
|
||||
self.outputs.remove(self.write_sock)
|
||||
|
||||
if self.write_sock not in self.inputs:
|
||||
self.inputs.append(self.write_sock)
|
||||
|
||||
def run(self):
|
||||
cache = b""
|
||||
state = dict()
|
||||
|
||||
settings_path = Path(__file__).resolve().parent / "settings.json"
|
||||
settings_text = settings_path.read_text()
|
||||
settings = json_loads(settings_text)
|
||||
state["_flags"] = dict()
|
||||
state["settings"] = settings
|
||||
#Thread(target=self.process_events, daemon=True).start()
|
||||
|
||||
self._connect()
|
||||
while not state.get("stop"):
|
||||
|
||||
self.inputs = [self.sock]
|
||||
while not self.state.get("stop"):
|
||||
socks = select(self.inputs, self.outputs, self.inputs, 0)
|
||||
if not any(socks):
|
||||
eprint("ERROR: Timed out!")
|
||||
|
||||
|
||||
# NOTE - Anything to do while it waits should be done here:
|
||||
# <waiting>
|
||||
self.process_events()
|
||||
# </waiting>
|
||||
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
read_socks, write_socks, exc_socks = socks
|
||||
if exc_socks:
|
||||
print(exc_socks)
|
||||
|
||||
has_read_sock = len(read_socks) > 0
|
||||
has_write_sock = len(write_socks) > 0
|
||||
|
||||
if has_read_sock:
|
||||
self.read_sock = read_socks[0]
|
||||
|
||||
if has_write_sock:
|
||||
self.write_sock = write_socks[0]
|
||||
|
||||
# This needs to be done first or it will wait on response to run
|
||||
requeue = list()
|
||||
while not self.actions.empty():
|
||||
action = self.actions.get()
|
||||
written = self.write(action)
|
||||
if not written:
|
||||
requeue.append(action)
|
||||
|
||||
else:
|
||||
#print("Done writing")
|
||||
for rq in requeue:
|
||||
print("Requeuing")
|
||||
self.actions.put(rq)
|
||||
|
||||
if not has_read_sock:
|
||||
self.toggle_mode()
|
||||
|
||||
data = cache
|
||||
try:
|
||||
packet = self.sock.recv(512)
|
||||
packet = self.read_sock.recv(512)
|
||||
|
||||
except sock_timeout as e:
|
||||
#eprint("Timed out on read")
|
||||
#self.process_events()
|
||||
self.toggle_mode()
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
eprint(f"ERROR:\n{format_exc()}")
|
||||
|
@ -196,6 +381,7 @@ class Bot:
|
|||
|
||||
data = data + packet
|
||||
if len(data) == 0:
|
||||
self.toggle_mode()
|
||||
continue
|
||||
|
||||
newline = b"\n"
|
||||
|
@ -221,26 +407,46 @@ class Bot:
|
|||
if len(message) == 0:
|
||||
continue
|
||||
|
||||
#print(message)
|
||||
parsed = self.parse_irc_message(message.strip("\r"))
|
||||
#print(parsed)
|
||||
|
||||
# DEBUG
|
||||
lag = 0
|
||||
tags = parsed.get("tags")
|
||||
if tags:
|
||||
time_tag = tags.get("time")
|
||||
time_str = "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
convert = datetime.strptime(time_tag, time_str).timetuple()
|
||||
msgnow = mktime(convert)
|
||||
now = datetime.utcnow().timestamp()
|
||||
lag = now - msgnow
|
||||
|
||||
print(f"<< [[lag={lag:.2f}s]] {message}")
|
||||
#print("%%", parsed)
|
||||
|
||||
|
||||
for plugin_name, plugin_callbacks in self.plugins.items():
|
||||
for callback in plugin_callbacks:
|
||||
try:
|
||||
state = callback(self, state, parsed)
|
||||
callback(self, parsed)
|
||||
|
||||
except Exception as e:
|
||||
eprint(f"ERROR:\n{format_exc()}")
|
||||
continue
|
||||
|
||||
if state is None:
|
||||
if self.state is None:
|
||||
eprint(f"ERROR: {plugin_name} returned None")
|
||||
continue
|
||||
|
||||
flags = state.get("_flags")
|
||||
flags = self.state.get("_flags")
|
||||
for key, value in flags.items():
|
||||
state[key] = value
|
||||
self.state[key] = value
|
||||
|
||||
self.toggle_mode()
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = Bot("irc.tilde.chat", 6697, "plugins", True)
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("-c", nargs="?", type=FileType("r"), dest="config")
|
||||
args = parser.parse_args()
|
||||
|
||||
bot = Bot(config=args.config)
|
||||
bot.run()
|
||||
|
|
|
@ -1,31 +1,31 @@
|
|||
def bot_callback(bot, state, message):
|
||||
def bot_callback(bot, message):
|
||||
command = message.get("command")
|
||||
params = message.get("params")
|
||||
has_tags = message.get("has_tags")
|
||||
tags = message.get("tags", dict())
|
||||
account = tags.get("account")
|
||||
settings = state.get("settings", dict())
|
||||
settings = bot.state.get("settings", dict())
|
||||
irc_settings = settings.get("irc", dict())
|
||||
author = irc_settings.get("author", "aewens")
|
||||
farewell = irc_settings.get("farewell", "Bye-bye!")
|
||||
|
||||
if command == "ERROR":
|
||||
state["stop"] = True
|
||||
bot.state["stop"] = True
|
||||
bot.send("QUIT :encountered error!")
|
||||
return state
|
||||
|
||||
if command == "PING":
|
||||
bot.send(f"PONG :{params}")
|
||||
|
||||
# NOTE - Your code goes here
|
||||
if command == "PRIVMSG" and state.get("joined"):
|
||||
if command == "PRIVMSG" and bot.state.get("joined"):
|
||||
NotImplemented
|
||||
|
||||
if command == "INVITE":
|
||||
name, channel = params.split(" ", 1)
|
||||
bot.join(channel)
|
||||
|
||||
if all([command == "PRIVMSG", state.get("joined"), has_tags, account == author]):
|
||||
if all([command == "PRIVMSG", bot.state.get("joined"), has_tags,
|
||||
account == author]):
|
||||
channel, privmsg = bot.privmsg(message)
|
||||
if privmsg.startswith("!join"):
|
||||
action, chan = privmsg.split(" ", 1)
|
||||
|
@ -34,5 +34,3 @@ def bot_callback(bot, state, message):
|
|||
if privmsg.startswith("!leave"):
|
||||
action, chan = privmsg.split(" ", 1)
|
||||
bot.leave(chan, farewell)
|
||||
|
||||
return state
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
def bot_callback(bot, state, message):
|
||||
def bot_callback(bot, message):
|
||||
command = message.get("command")
|
||||
params = message.get("params")
|
||||
has_tags = message.get("has_tags")
|
||||
|
@ -8,19 +8,19 @@ def bot_callback(bot, state, message):
|
|||
source = message.get("source")
|
||||
nick = message.get("nick")
|
||||
|
||||
settings = state.get("settings", dict())
|
||||
settings = bot.state.get("settings", dict())
|
||||
irc_settings = settings.get("irc", dict())
|
||||
bot_name = irc_settings.get("name", "ircbot")
|
||||
|
||||
flags = state.get("_flags")
|
||||
flags = bot.state.get("_flags")
|
||||
|
||||
if all([source == "server", command == "NOTICE", not state.get("init"),
|
||||
if all([source == "server", command == "NOTICE", not bot.state.get("init"),
|
||||
"Looking up your hostname" in params]):
|
||||
flags["init"] = True
|
||||
flags["caps"] = False
|
||||
bot.send("CAP LS")
|
||||
|
||||
if all([source == "server", command == "CAP", not state.get("caps"),
|
||||
if all([source == "server", command == "CAP", not bot.state.get("caps"),
|
||||
"* LS" in params]):
|
||||
capabilities = params.split("* LS", 1)[1].strip()
|
||||
requirements = irc_settings.get("requirements", list())
|
||||
|
@ -40,7 +40,7 @@ def bot_callback(bot, state, message):
|
|||
else:
|
||||
flags["stop"] = True
|
||||
|
||||
if all([source == "server", command == "CAP", not state.get("caps"),
|
||||
if all([source == "server", command == "CAP", not bot.state.get("caps"),
|
||||
"* ACK" in params]):
|
||||
flags["caps"] = True
|
||||
bot.send("CAP END")
|
||||
|
@ -48,20 +48,17 @@ def bot_callback(bot, state, message):
|
|||
bot.send(f"NICK {bot_name}")
|
||||
|
||||
if all([source == "nick", nick == "NickServ", command == "NOTICE",
|
||||
not state.get("identify"), any(["choose a different nick" in params,
|
||||
not bot.state.get("identify"), any(["choose a different nick" in params,
|
||||
"Your nickname is not registered" in params])]):
|
||||
bot_pass = irc_settings.get("password", "password")
|
||||
flags["identify"] = True
|
||||
bot.send(f"PRIVMSG NickServ :IDENTIFY {bot_name} {bot_pass}")
|
||||
|
||||
if all([source == "nick", nick == "NickServ", command == "MODE",
|
||||
state.get("identify"), "+r" in params]):
|
||||
bot.state.get("identify"), "+r" in params]):
|
||||
flags["registered"] = True
|
||||
bot.send(f"MODE {bot_name} +B")
|
||||
channels = irc_settings.get("auto-join", list())
|
||||
for channel in channels:
|
||||
bot.join(channel)
|
||||
|
||||
flags["joined"] = True
|
||||
|
||||
return state
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
def bot_callback(bot, state, message):
|
||||
def bot_callback(bot, message):
|
||||
command = message.get("command")
|
||||
params = message.get("params")
|
||||
has_tags = message.get("has_tags")
|
||||
|
@ -8,22 +8,20 @@ def bot_callback(bot, state, message):
|
|||
source = message.get("source")
|
||||
nick = message.get("nick")
|
||||
|
||||
settings = state.get("settings", dict())
|
||||
settings = bot.state.get("settings", dict())
|
||||
irc_settings = settings.get("irc", dict())
|
||||
bot_name = irc_settings.get("name", "ircbot")
|
||||
author = irc_settings.get("author", "aewens")
|
||||
about = irc_settings.get("about", "I am made by {}")
|
||||
|
||||
flags = state.get("_flags")
|
||||
flags = bot.state.get("_flags")
|
||||
|
||||
if all([source == "nick", nick == "NickServ", state.get("registered"),
|
||||
not state.get("is-bot")]):
|
||||
if all([source == "nick", nick == "NickServ", bot.state.get("registered"),
|
||||
not bot.state.get("is-bot")]):
|
||||
flags["is-bot"] = True
|
||||
bot.send(f"MODE {bot_name} +B")
|
||||
#bot.send(f"MODE {bot_name} +B")
|
||||
|
||||
if all([command == "PRIVMSG", state.get("joined")]):
|
||||
if all([command == "PRIVMSG", bot.state.get("joined")]):
|
||||
channel, privmsg = bot.privmsg(message)
|
||||
if privmsg in ["!botlist", "!rollcall"]:
|
||||
bot.sendto(channel, about.format(author))
|
||||
|
||||
return state
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
{
|
||||
"_meta": {
|
||||
"server": "irc.tilde.chat",
|
||||
"port": 6697,
|
||||
"secure": true,
|
||||
"timeout": 0.1,
|
||||
"plugins_dir": "plugins"
|
||||
},
|
||||
"irc": {
|
||||
"name": "",
|
||||
"password": "",
|
||||
"author": "",
|
||||
"about": "",
|
||||
"auto-join": [
|
||||
"#bots"
|
||||
"#bots",
|
||||
"#babili"
|
||||
],
|
||||
"requirements": [
|
||||
"account-notify",
|
||||
|
@ -15,7 +23,8 @@
|
|||
"chghost",
|
||||
"extended-join",
|
||||
"message-tags",
|
||||
"server-time"
|
||||
"server-time",
|
||||
"echo-message"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue