ManifoldMarketManager/ManifoldMarketManager/state/persistant.py

232 lines
8.0 KiB
Python

"""Store state between sessions."""
from __future__ import annotations
from logging import getLogger
from os import getenv
from pathlib import Path
from sqlite3 import PARSE_COLNAMES, PARSE_DECLTYPES, connect
from typing import TYPE_CHECKING
from uuid import NAMESPACE_DNS, uuid5
from ..account import Account
from ..consts import EnvironmentVariable
from ..util import require_env
if TYPE_CHECKING: # pragma: no cover
from datetime import datetime
from sqlite3 import Connection, Cursor
from typing import Any, Callable, Iterable, Sequence
from ..market import Market
from ..util import T
# Module-level logger, named after this module's dotted import path.
logger = getLogger(name=__name__)
def db_wrapper(func: Callable[..., T]) -> Callable[..., T]:
    """Wrap a function so that it automatically gets a reference to the database if one is not provided.

    The wrapped function must accept a ``db`` keyword argument. If the caller passes ``db``, it is handed straight
    through and the caller remains responsible for committing. Otherwise a fresh connection is obtained from
    ``register_db()`` and used as a ``with`` context manager, which (for ``sqlite3.Connection``) commits when
    ``func`` returns normally and rolls back when it raises.

    The wrapper preserves the wrapped function's name, docstring and ``__wrapped__`` via ``functools.wraps``.

    NOTE(review): the freshly opened connection is deliberately not closed here. Wrapped generator functions
    (e.g. ``select_markets``) only start executing after this wrapper has returned, so closing the connection
    on exit would break them.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args: Any, db: Connection | None = None, **kwargs: Any) -> T:
        if db is None:
            with register_db() as db:
                return func(*args, db=db, **kwargs)
        return func(*args, db=db, **kwargs)

    return wrapper
@require_env(EnvironmentVariable.DBName)
def register_db() -> Connection:
    """Get a connection to the appropriate database for this bot.

    The database path is read from the ``DBName`` environment variable. The connection is opened with
    ``PARSE_COLNAMES | PARSE_DECLTYPES``, so any sqlite3 converters registered for the declared column types
    (``Account``, ``Market``, ...) are applied when reading.

    The schema is applied with ``CREATE TABLE IF NOT EXISTS`` on every call. This ensures that a database file that
    exists but was never fully initialized (e.g. an interrupted first run) still ends up with every table. The
    "initialized" log line is only emitted when the file did not exist beforehand.

    Returns:
        An open ``sqlite3.Connection``; the caller owns it.

    Raises:
        EnvironmentError: if ``DBName`` is not set.
    """
    name = getenv("DBName")
    if name is None:
        raise EnvironmentError("the DBName environment variable must be set to the database path")
    is_new = not Path(name).exists()
    conn = connect(name, detect_types=PARSE_COLNAMES | PARSE_DECLTYPES)
    try:
        # NOTE(review): the column was previously created as "manidfold_id" (typo) while select_account() queries
        # "manifold_id". Databases created before this fix still carry the old name and need a manual
        # `ALTER TABLE accounts RENAME COLUMN manidfold_id TO manifold_id`.
        conn.execute(
            "CREATE TABLE IF NOT EXISTS accounts "
            "(id INTEGER PRIMARY KEY AUTOINCREMENT, manifold_id TEXT NOT NULL, username TEXT NOT NULL, "
            "raw_account BLOB, is_encrypted BOOLEAN, account Account)"
        )
        conn.execute(
            "CREATE TABLE IF NOT EXISTS markets "
            "(id INTEGER PRIMARY KEY AUTOINCREMENT, market Market, check_rate REAL NOT NULL, last_checked TIMESTAMP, "
            "account INTEGER REFERENCES \"accounts\" (\"id\") ON DELETE SET NULL)"
        )
        conn.execute(
            "CREATE TABLE IF NOT EXISTS pending "
            "(id INTEGER PRIMARY KEY AUTOINCREMENT, request ManagerRequest NOT NULL, priority REAL NOT NULL, "
            "account INTEGER REFERENCES \"accounts\" (\"id\") ON DELETE SET NULL)"
        )
        # NOTE(review): the declared type "Namesapce" looks like a typo for "Namespace". With PARSE_DECLTYPES,
        # converter lookup uses this name, so confirm which spelling any registered converter uses before changing.
        conn.execute(
            "CREATE TABLE IF NOT EXISTS scanners "
            "(id INTEGER PRIMARY KEY AUTOINCREMENT, scanner EventEmitter NOT NULL, state Namesapce, "
            "check_rate REAL NOT NULL, last_checked TIMESTAMP, "
            "account INTEGER REFERENCES \"accounts\" (\"id\") ON DELETE SET NULL)"
        )
        conn.commit()
    except BaseException:
        # Don't leak the connection if initialization fails; propagate the original error.
        conn.close()
        raise
    if is_new:
        logger.info("Database up and initialized.")
    return conn
@db_wrapper
def remove_markets(
    *row_id: int,
    db: Connection = None  # type: ignore[assignment]
) -> None:
    """Attempt to delete the markets with the given row IDs from the database.

    IDs that match no row are ignored (a DELETE matching nothing is not an error), and calling with no IDs is a
    no-op. Nothing is committed here. When ``db`` is omitted, the ``db_wrapper`` connection's context manager
    commits on return; when ``db`` is passed, committing is the caller's job.
    """
    assert db is not None
    if not row_id:
        # Without this guard the statement would end in a dangling "WHERE" and raise a syntax error.
        return
    placeholders = ", ".join("?" * len(row_id))
    db.execute(f"DELETE FROM markets WHERE id IN ({placeholders})", row_id)
@db_wrapper
def find_account(
    account: Account,
    db: Connection = None  # type: ignore[assignment]
) -> int:
    """Find the database ID of a registered account.

    The row is looked up by ``account.ManifoldUsername``, and ``account.key`` is forwarded so an encrypted row can be
    decrypted. The lookup runs on the connection this function was given instead of opening a second one.

    Raises:
        Whatever ``select_account`` raises when the username does not match exactly one row, or when decryption
        fails.
    """
    assert db is not None
    id_, _ = select_account(username=account.ManifoldUsername, key=account.key, db=db)
    return id_
@db_wrapper
def update_market(
    row_id: int,
    market: Market | None = None,
    check_rate: float | None = None,
    last_checked: datetime | None = None,
    account_id: int | None = None,
    account: Account | None = None,
    db: Connection = None  # type: ignore[assignment]
) -> None:
    """Attempt to update a market in the database.

    Only the arguments that are not ``None`` are written; the other columns keep their current values (so this
    cannot be used to clear a column to NULL).

    The owning account is stored in the ``markets.account`` column as an ``accounts.id`` integer, which is how
    ``select_markets`` reads it back. It may be given directly as ``account_id``, or as an ``Account`` object that
    is resolved to its ID with ``find_account``.

    Raises:
        ValueError: if there is nothing to update, or if ``account_id`` and ``account`` are both given but refer
            to different accounts.
    """
    assert db is not None
    if account is not None:
        resolved_id = find_account(account, db=db)
        if account_id is not None and account_id != resolved_id:
            raise ValueError("account_id and account refer to different accounts")
        account_id = resolved_id
    columns = {
        "market": market,
        "check_rate": check_rate,
        "last_checked": last_checked,
        # The column is named "account" but holds the integer accounts.id.
        "account": account_id,
    }
    updates = {column: value for column, value in columns.items() if value is not None}
    if not updates:
        raise ValueError("you need to actually update something")
    # Column names come from the fixed mapping above, never from callers, so interpolating them is safe.
    assignments = ", ".join(f"{column}=?" for column in updates)
    db.execute(f"UPDATE markets SET {assignments} WHERE id=?", (*updates.values(), row_id))
@db_wrapper
def select_markets(
    keys: Sequence[bytes] = (),
    db: Connection = None  # type: ignore[assignment]
) -> Iterable[tuple[int, Market, float, datetime | None, Account | None]]:
    """Attempt to load ALL market objects from the database, with their associated metadata.

    This is a generator that yields ``(row_id, market, check_rate, last_checked, account)`` for each row of
    ``markets``. ``account`` is ``None`` when the row has no associated account ID.

    Requires: some number of keys if your market has encrypted accounts associated with it. Candidate keys are
    ``keys`` followed by the non-empty comma-separated hex strings in the ``AccountKeys`` environment variable. Each
    candidate is tried in turn until the account loads. If none work, the error from the last attempt is raised.
    When there are no candidates at all, an empty key is tried, which suffices for unencrypted accounts. Each
    distinct account is loaded only once per call.

    Depends on: select_account()
    """
    assert db is not None
    key_strs = getenv(EnvironmentVariable.AccountKeys, "").split(",")
    candidates = (*keys, *(bytes.fromhex(x) for x in key_strs if x.strip()))
    if not candidates:
        candidates = (b'',)
    loaded: dict[int, Account] = {}

    def load_account(account_id: int) -> Account:
        """Load the account with ``account_id``, trying each candidate key in order."""
        errors: list[Exception] = []
        for key in candidates:
            try:
                return select_account(db_id=account_id, key=key, db=db)[1]
            except Exception as err:  # decryption error type comes from Account.from_bytes, not visible here
                errors.append(err)
        # candidates is never empty, so at least one error was recorded.
        raise errors[-1]

    row: tuple[int, Market, float, datetime | None, int | None]
    for row in db.execute("SELECT * from markets"):
        row_id, market, check_rate, last_checked, *extra = row
        account_id: int | None = extra[0] if extra else None
        account: Account | None = None
        if account_id is not None:
            if account_id not in loaded:
                loaded[account_id] = load_account(account_id)
            account = loaded[account_id]
        yield (row_id, market, check_rate, last_checked, account)
@db_wrapper
def select_account(
    db_id: int | None = None,
    manifold_id: str | None = None,
    username: str | None = None,
    key: bytes = b'',
    db: Connection = None  # type: ignore[assignment]
) -> tuple[int, Account]:
    """Attempt to load and decrypt a SINGLE account object from the database.

    Rows must match every criterion given (``db_id``, ``manifold_id``, ``username``), combined with AND. If the
    matching row is flagged ``is_encrypted``, the account is rebuilt from ``raw_account`` with
    ``Account.from_bytes(raw_account, key)``; otherwise the stored ``account`` column value is returned.

    Raises:
        ValueError: if no criteria are given, or if the criteria do not match exactly one row.
        Whatever ``Account.from_bytes`` raises if the account cannot be decrypted.
    """
    assert db is not None
    criteria = {
        "id": db_id,
        "manifold_id": manifold_id,
        "username": username,
    }
    given = {column: value for column, value in criteria.items() if value is not None}
    if not given:
        raise ValueError("at least one of db_id, manifold_id or username is required")
    # Column names come from the fixed mapping above; only values are bound as parameters.
    where = " AND ".join(f"{column} = ?" for column in given)
    rows = db.execute(
        f"SELECT id, raw_account, account, is_encrypted FROM accounts WHERE {where}",
        tuple(given.values()),
    ).fetchall()
    if len(rows) != 1:
        raise ValueError(f"expected exactly one matching account, found {len(rows)}")
    ((id_, raw_account, account, is_encrypted), ) = rows
    if is_encrypted:
        account = Account.from_bytes(raw_account, key)
    return (id_, account)
class DatabaseNamespace:
    """Represent a namespace in the database for use by various rules and plugins.

    This requires you to give a schema and a DNS-formatted table name. The namespace is backed by a single table
    whose name is the hex form of ``uuid5(NAMESPACE_DNS, name)``.
    """

    # Declared SQL types for plain Python types used in a schema. Other Python types are declared by their class
    # name, so a sqlite3 converter registered under that name can be matched via PARSE_DECLTYPES.
    _SQL_TYPES = {int: "INTEGER", float: "REAL", str: "TEXT", bytes: "BLOB", bool: "BOOLEAN"}
    _HEX_DIGITS = frozenset("0123456789abcdef")

    def __init__(self, name: str, schema: dict[str, str | type]):
        """Given a name and schema, get a helper object to interact with only your part of the database.

        Schema should be formatted as a dictionary of column names to types. A type may be an SQL type string
        (e.g. ``"TEXT NOT NULL"``) or a Python type (e.g. ``int``), which is translated to an SQL type.

        Name should be formatted as a URI that describes your table. For instance, if I was making a table to store
        state for a scanner of OpenStreetMap tasks, I might name it `scanner.osm.projects`.
        """
        self.uuid = uuid5(NAMESPACE_DNS, name).hex
        str_schema = self._schema_sql(schema)
        # The hex UUID may begin with a digit, so the identifier must be double-quoted to be valid SQL.
        self.execute(f'CREATE TABLE IF NOT EXISTS "{self.uuid}" ({str_schema})', commit=True)

    @classmethod
    def _schema_sql(cls, schema: dict[str, str | type]) -> str:
        """Render ``schema`` as the column-definition list of a CREATE TABLE statement."""
        columns = []
        for column, type_ in schema.items():
            if isinstance(type_, type):
                type_ = cls._SQL_TYPES.get(type_, type_.__name__)
            columns.append(f"{column} {type_}")
        return ", ".join(columns)

    def _check_query(self, query: str) -> None:
        """Raise ValueError unless the table name is a 32-char lowercase hex UUID and ``query`` contains no ``;``."""
        if len(self.uuid) != 32 or not set(self.uuid) <= self._HEX_DIGITS:
            raise ValueError("namespace table name is not a 32-character hex UUID")
        if ';' in query:
            raise ValueError("queries may not contain ';'")

    def execute(self, query: str, commit: bool = False, params: Sequence[Any] = ()) -> Cursor:
        """Run ``query`` (with bound ``params``) on a fresh connection and return the resulting cursor.

        Perform basic sanitization that I don't expect to defeat real effort unless you use this responsibly: the
        query is rejected if it contains ``;`` or if this namespace's table name is not a 32-char hex UUID. Pass
        values through ``params`` rather than formatting them into ``query``.

        Raises:
            ValueError: if the sanitization check fails.
        """
        self._check_query(query)
        with register_db() as db:
            ret = db.execute(query, params)
            if commit:
                db.commit()
        return ret

    def select(
        self,
        names: Sequence[str] = ("*", )
    ) -> Cursor:
        """Select the given columns (all by default) from your database namespace."""
        return self.execute(f'SELECT {", ".join(names)} FROM "{self.uuid}"')

    def remove(
        self,
        names: Sequence[str] = ("*", )
    ) -> Cursor:
        """Remove from your database namespace."""
        raise NotImplementedError()

    def update(
        self,
        names: Sequence[str] = ("*", )
    ) -> Cursor:
        """Update values in your database namespace."""
        raise NotImplementedError()