abots/abots/helpers/black_magic.py

79 lines
2.7 KiB
Python

from collections import defaultdict
from functools import wraps
from os.path import dirname, basename, isfile, join
from glob import glob
from importlib import import_module
def infinitedict():
    """Return an arbitrarily deep auto-vivifying dictionary.

    Looking up a missing key at any depth inserts and returns another
    dictionary of the same kind, so ``d["a"]["b"]["c"] = 1`` works without
    creating the intermediate levels first.
    """
    def _nested():
        # Each missing key produces a new defaultdict that uses this same
        # factory, which is what makes the nesting unbounded.
        return defaultdict(_nested)

    return defaultdict(_nested)
def debugger(func):
    """Decorator that prints each call to ``func`` and the value it returns.

    Before the call, one line shows the function name with the ``repr`` of
    every positional argument and ``key=repr(value)`` for every keyword
    argument. After the call, a second line shows the function name and the
    ``repr`` of the result. The result itself is returned unchanged.
    """
    @wraps(func)
    def _traced(*args, **kwargs):
        # Positional reprs first, then keyword pairs, in call order.
        shown = list(map(repr, args))
        shown.extend(f"{key}={value!r}" for key, value in kwargs.items())
        print(f"[DEBUGGER]: Calling {func.__name__}({', '.join(shown)})")
        outcome = func(*args, **kwargs)
        print(f"[DEBUGGER]: {func.__name__!r} returned {outcome!r}")
        return outcome

    return _traced
def coroutine(func):
    """Decorator that primes a generator-based coroutine.

    Calling the decorated function creates the generator and immediately
    advances it to its first ``yield``, discarding the yielded value. The
    caller can therefore ``send()`` into it straight away, without a priming
    ``next()`` call.
    """
    @wraps(func)
    def _primed(*args, **kwargs):
        gen = func(*args, **kwargs)
        advance = gen.__next__
        advance()  # run up to the first ``yield``
        return gen

    return _primed
def generator(func):
    """Turn a generator function into a primed, endlessly restarting coroutine.

    Whenever the generator produced by ``func(*args, **kwargs)`` is exhausted,
    a new one is created with the same arguments. The wrapper therefore only
    finishes when it is closed. Values passed to ``send()``/``throw()`` are
    delegated to the current inner generator through ``yield from``. The
    wrapper is passed through ``coroutine``, so calling the decorated function
    returns a generator already advanced to its first ``yield``.

    NOTE(review): if ``func`` returns a generator that yields nothing, the
    restart loop never reaches a ``yield`` and spins forever, including during
    priming.
    """
    @wraps(func)
    def wrapper_generator(*args, **kwargs):
        def fresh():
            return func(*args, **kwargs)

        try:
            while 1:
                inner = fresh()
                yield from inner
        except GeneratorExit:
            # close() was called: finish quietly instead of re-raising.
            return

    primed = coroutine(wrapper_generator)
    return primed
def singleton(cls):
    """Class decorator that makes ``cls`` produce a single shared instance.

    The decorated name is bound to a wrapper function, not to the class. The
    first call constructs ``cls(*args, **kwargs)``. Every later call returns
    that same object and ignores its arguments. The cached object is stored on
    the wrapper as ``.instance``; assigning ``None`` to it makes the next call
    construct a new instance.
    """
    @wraps(cls)
    def wrapper_singleton(*args, **kwargs):
        # Compare with None rather than relying on truthiness: an instance
        # whose __bool__/__len__ reports falsy must still be reused.
        if wrapper_singleton.instance is None:
            wrapper_singleton.instance = cls(*args, **kwargs)
        return wrapper_singleton.instance

    wrapper_singleton.instance = None
    return wrapper_singleton
def curry(func, argc=None):
    """Return a curried version of ``func``.

    The returned callable collects positional arguments across any number of
    successive calls. Once at least ``argc`` arguments have been supplied in
    total, ``func`` is called with all of them and its result is returned.
    Until then, each call returns another curried callable.

    Args:
        func: The callable to curry.
        argc: How many positional arguments to collect before calling
            ``func``. Defaults to the positional parameter count of ``func``'s
            code object, so in that case ``func`` must be a plain Python
            function.

    Returns:
        Either ``func``'s result or a further curried callable.
    """
    if argc is None:
        # Python 3 exposes the code object as ``__code__``; ``func_code`` was
        # the Python 2 spelling and raises AttributeError here.
        argc = func.__code__.co_argcount

    @wraps(func)
    def wrapper_curry(*args):
        # ``>=`` so that surplus arguments reach ``func`` (which then reports
        # them) instead of yielding a curry with a negative count that can
        # never be satisfied.
        if len(args) >= argc:
            return func(*args)

        def curried(*c_args):
            return func(*(args + c_args))

        return curry(curried, argc - len(args))

    return wrapper_curry
# Automatically imports the Python files in the directory containing `location`.
# Essentially, it is an auto-loader for a plugin system.
# NOTE: Do as I say, not as I do. You should probably never do this.
def autoload(location, context, package, prefix=""):
    """Import every module next to ``location`` and publish its public names.

    Each ``*.py`` file in the directory containing ``location`` (except
    ``__init__.py``) is imported as ``<package>.<module>``. Every attribute of
    the imported module whose name does not start with ``_`` is then stored in
    ``context`` under ``prefix + name``. Importing a module runs its top-level
    code, which is how plugin modules fire their decorators.

    Args:
        location: Path of a file inside the plugin directory (presumably the
            package's ``__init__.py`` ``__file__`` — confirm at call sites).
        context: Mutable mapping that receives the names (presumably the
            caller's ``globals()`` — confirm at call sites).
        package: Dotted package name used as the anchor for the relative
            import. Each module must be importable as ``<package>.<module>``.
        prefix: String prepended to every published name to reduce clashes
            with names already present in ``context``.
    """
    # glob returns paths in arbitrary order. Sorting makes the import order,
    # and which plugin wins on a name clash, reproducible.
    for path in sorted(glob(join(dirname(location), "*.py"))):
        filename = basename(path)
        if not isfile(path) or filename == "__init__.py":
            continue
        # The module name is the filename with exactly its ".py" extension
        # removed, independent of how deep ``package`` is.
        module_name = filename[: -len(".py")]
        # Equivalent of doing "import <package>.<module>"
        plugin = import_module(f".{module_name}", package)
        for attr in dir(plugin):
            if attr.startswith("_"):
                continue
            # Roughly "from <package>.<module> import *", except that
            # __all__ is not consulted, so names the plugin imported itself
            # are published too. The prefix reduces conflicts in ``context``;
            # these objects are not meant to be used directly — importing
            # them just fires off their decorators.
            context[f"{prefix}{attr}"] = getattr(plugin, attr)