import os
import sqlite3
import time
from typing import List, Optional, Set, Tuple

from post import Post


class DB:
    """SQLite-backed storage for posts and registered users.

    The database file location is read from ``settings["db"]["path"]``;
    per-post-type behaviour (sorting, paging, scoring) is read from
    ``settings["postTypes"]``.
    """

    settings: dict
    name: str
    connection: sqlite3.Connection
    cursor: sqlite3.Cursor

    def __init__(self, settings: dict) -> None:
        """Open the database described by *settings*, creating it if needed.

        When the database file does not exist yet, the schema is created and
        the posts listed under ``settings["defaultPosts"]`` are inserted.
        """
        self.settings = settings
        db_file_path: str = self.settings["db"]["path"]
        # Must be checked before connect(), which creates the file.
        is_new_database: bool = not os.path.exists(db_file_path)
        self.connection = sqlite3.connect(db_file_path)
        self.cursor = self.connection.cursor()
        if is_new_database:
            self.init_post_table()
            self.create_default_posts()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.connection.close()

    def init_post_table(self) -> None:
        """Create the ``posts`` and ``users`` tables if they are missing."""
        # Column order matters: rows are read back with SELECT * and indexed
        # positionally (see _post_from_row).
        self.cursor.executescript(
            """CREATE TABLE IF NOT EXISTS posts
            (ip_address TEXT, post_type TEXT, post_id INTEGER,
             timestamp INTEGER, path TEXT, title TEXT, text TEXT,
             options TEXT, username TEXT, cert_hash TEXT, parent INTEGER,
             num_children INTEGER, score INTEGER);
            CREATE TABLE IF NOT EXISTS users
            (username TEXT, client_cert TEXT, cert_hash TEXT, status TEXT);"""
        )
        self.connection.commit()

    def latest_post_id(self) -> int:
        """Return the highest ``post_id`` in use, or 0 when there are no posts."""
        latest_post: Optional[int] = self.connection.execute(
            "SELECT MAX(post_id) FROM posts"
        ).fetchone()[0]
        return 0 if latest_post is None else latest_post

    def create_default_posts(self) -> None:
        """Insert every post configured under ``settings["defaultPosts"]``.

        Each key of that mapping is used as the post title; each post gets
        the next free ``post_id`` and the current time in nanoseconds.
        """
        for title, post_settings in self.settings["defaultPosts"].items():
            self.add_post(
                ip_address="localhost",
                post_type=post_settings["postType"],
                post_id=self.latest_post_id() + 1,
                timestamp=time.time_ns(),
                path=post_settings["path"],
                text=post_settings["text"],
                title=title,
                parent=post_settings["parent"],
            )

    def hash_to_cert(self, cert_hash: str) -> Optional[str]:
        """Return the client certificate stored for *cert_hash*, or None."""
        row: Optional[Tuple[str]] = self.cursor.execute(
            "SELECT client_cert FROM users WHERE cert_hash=?", (cert_hash,)
        ).fetchone()
        return None if row is None else row[0]

    def path_to_id(self, path: str) -> Optional[int]:
        """Return the ``post_id`` of the post stored at *path*, or None."""
        row: Optional[Tuple[int]] = self.cursor.execute(
            "SELECT post_id FROM posts WHERE path=?", (path,)
        ).fetchone()
        return None if row is None else row[0]

    def set_user_status(self, status: str, username: str) -> None:
        """Set the ``status`` column of the user named *username*."""
        self.cursor.execute(
            "UPDATE users SET status = ? WHERE username=?", (status, username)
        )
        self.connection.commit()

    def append_to_post(self, post_id: int, text: str) -> None:
        """Append *text* to the body of post *post_id*.

        A NULL body is treated as empty. Does nothing if the post does not
        exist.
        """
        # COALESCE folds the "body is NULL" and "body has text" cases into a
        # single atomic statement.
        self.cursor.execute(
            "UPDATE posts SET text = COALESCE(text, '') || ? WHERE post_id=?",
            (text, post_id),
        )
        self.connection.commit()

    def add_user(self, username: str, user_cert: str, cert_hash: str) -> bool:
        """Register a user with status ``"registered"``.

        Returns:
            True if the user was added, False if *username* is already taken.
        """
        user_exists: bool = (
            self.cursor.execute(
                "SELECT * FROM users WHERE username=?", (username,)
            ).fetchone()
            is not None
        )
        if user_exists:
            return False
        self.cursor.execute(
            "INSERT INTO users VALUES (?,?,?,?)",
            (username, user_cert, cert_hash, "registered"),
        )
        self.connection.commit()
        return True

    def get_user_cert(self, username: str) -> Optional[str]:
        """Return the client certificate of *username*, or None if unknown."""
        row: Optional[Tuple[str]] = self.cursor.execute(
            "SELECT client_cert FROM users WHERE username=?", (username,)
        ).fetchone()
        return None if row is None else row[0]

    def del_post_by_id(self, post_id: int, children: Optional[bool] = True) -> None:
        """Delete a post and, by default, every post beneath it.

        Args:
            post_id: The post to delete.
            children: When truthy, also delete all descendants (children,
                grandchildren, ...) so none are left pointing at a deleted
                parent.

        The deleted post's parent has its ``num_children`` counter decremented
        to stay consistent with the increment performed by add_post.
        """
        row: Optional[Tuple[Optional[int]]] = self.cursor.execute(
            "SELECT parent FROM posts WHERE post_id=?", (post_id,)
        ).fetchone()
        parent_id: Optional[int] = None if row is None else row[0]

        if children:
            pending: List[int] = [post_id]
            seen: Set[int] = set()  # guards against cycles in corrupt data
            while pending:
                current: int = pending.pop()
                if current in seen:
                    continue
                seen.add(current)
                child_rows = self.cursor.execute(
                    "SELECT post_id FROM posts WHERE parent=?", (current,)
                ).fetchall()
                pending.extend(child_id for (child_id,) in child_rows)
                self.cursor.execute("DELETE FROM posts WHERE post_id=?", (current,))
        else:
            self.cursor.execute("DELETE FROM posts WHERE post_id=?", (post_id,))

        if row is not None and parent_id is not None:
            self.cursor.execute(
                "UPDATE posts SET num_children = num_children - 1 WHERE post_id=?",
                (parent_id,),
            )
        self.connection.commit()

    def update_post_score(self, post_id: int) -> None:
        """Recompute and store the score of post *post_id*.

        The score is the integer mean of the post's own timestamp and the
        timestamps of its direct children, so newer replies pull it upward.
        """
        own_timestamp: int = self.cursor.execute(
            "SELECT timestamp FROM posts WHERE post_id=?", (post_id,)
        ).fetchone()[0]
        child_timestamps: List[int] = [
            child_timestamp
            for (child_timestamp,) in self.cursor.execute(
                "SELECT timestamp FROM posts WHERE parent=?", (post_id,)
            ).fetchall()
        ]
        if child_timestamps:
            # Summed in Python on purpose: a few nanosecond timestamps added
            # together overflow SQLite's 64-bit integer SUM().
            score = (own_timestamp + sum(child_timestamps)) // (
                len(child_timestamps) + 1
            )
        else:
            score = own_timestamp
        self.cursor.execute(
            "UPDATE posts SET score = ? WHERE post_id=?", (score, post_id)
        )
        self.connection.commit()

    def add_post(
        self,
        ip_address: str,
        post_type: str,
        post_id: int,
        timestamp: int,
        path: Optional[str] = None,
        title: Optional[str] = None,
        text: Optional[str] = None,
        options: Optional[str] = None,
        username: Optional[str] = None,
        cert_hash: Optional[str] = None,
        parent: Optional[int] = None,
    ) -> None:
        """Insert a new post and update the bookkeeping of its parent.

        The stored path is *path* (default: ``str(post_id)``), prefixed with
        the parent's path and a ``/`` when *parent* is given. The parent's
        ``num_children`` is incremented, and the new post and/or its parent
        are rescored when their post type is configured as ``scored``.

        Raises:
            ValueError: If *parent* is given but no such post exists.
        """
        if path is None:
            path = str(post_id)

        parent_type: Optional[str] = None
        if parent is not None:
            parent_row = self.cursor.execute(
                "SELECT path, post_type FROM posts WHERE post_id=?", (parent,)
            ).fetchone()
            if parent_row is None:
                raise ValueError("parent post {} does not exist".format(parent))
            parent_path, parent_type = parent_row
            path = "{}/{}".format(parent_path, path)

        num_children: int = 0
        score: int = 0
        self.cursor.execute(
            "INSERT INTO posts VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)",
            (
                ip_address,
                post_type,
                post_id,
                timestamp,
                path,
                title,
                text,
                options,
                username,
                cert_hash,
                parent,
                num_children,
                score,
            ),
        )
        if parent is not None:
            self.cursor.execute(
                "UPDATE posts SET num_children = num_children + 1 WHERE post_id=?",
                (parent,),
            )
        self.connection.commit()

        post_types: dict = self.settings["postTypes"]
        if post_types[post_type]["scored"]:
            self.update_post_score(post_id)
        if parent is not None and post_types[parent_type]["scored"]:
            self.update_post_score(parent)

    def _post_from_row(self, row: tuple, **extra) -> Post:
        """Build a Post from a ``SELECT * FROM posts`` row."""
        # Row layout follows the CREATE TABLE in init_post_table:
        # 0 ip_address, 1 post_type, 2 post_id, 3 timestamp, 4 path, 5 title,
        # 6 text, 7 options, 8 username, 9 cert_hash, 10 parent,
        # 11 num_children, 12 score.
        return Post(
            settings=self.settings,
            post_type=row[1],
            post_id=row[2],
            timestamp=row[3],
            path=row[4],
            title=row[5],
            text=row[6],
            username=row[8],
            cert_hash=row[9],
            parent=row[10],
            num_children=row[11],
            **extra,
        )

    def get_post_by_id(
        self,
        post_id: int,
        page_num: Optional[int] = None,
    ) -> Optional[Post]:
        """Load a post together with (a page of) its direct children.

        Children are ordered descending by the post type's ``sorting``
        column. If the type's ``maxChildren`` is None, all children are
        returned. Otherwise, at most ``maxChildren`` children are returned:
        either the first ones, or, when the type is ``paginated``, page
        *page_num*. A paginated ``maxChildren`` must be positive. The page is
        clamped to ``1..max_page`` and an empty post still has page 1.

        Returns:
            The Post, or None if *post_id* does not exist.
        """
        post_data: Optional[tuple] = self.cursor.execute(
            "SELECT * FROM posts WHERE post_id=?", (post_id,)
        ).fetchone()
        if post_data is None:
            return None

        type_settings: dict = self.settings["postTypes"][post_data[1]]
        sorting: str = type_settings["sorting"]
        max_children: Optional[int] = type_settings["maxChildren"]
        paginated: bool = type_settings["paginated"]

        # Column names cannot be bound as SQL parameters, so the ORDER BY
        # column is interpolated. It comes from trusted settings, never from
        # request data.
        child_query: str = (
            f"SELECT * FROM posts WHERE parent=? ORDER BY {sorting} DESC"
        )

        if max_children is None:
            child_data = self.cursor.execute(child_query, (post_id,)).fetchall()
        elif paginated:
            num_children: int = self.cursor.execute(
                "SELECT COUNT(*) FROM posts WHERE parent=?", (post_id,)
            ).fetchone()[0]
            # Ceiling division; a post without children still has one page.
            max_page: int = max(1, -(-num_children // max_children))
            if page_num is None or page_num < 1:
                page_num = 1
            if page_num > max_page:
                page_num = max_page
            offset: int = (page_num - 1) * max_children
            child_data = self.cursor.execute(
                child_query + " LIMIT ? OFFSET ?",
                (post_id, max_children, offset),
            ).fetchall()
        else:
            child_data = self.cursor.execute(
                child_query + " LIMIT ?", (post_id, max_children)
            ).fetchall()

        children: List[Post] = [self._post_from_row(row) for row in child_data]
        return self._post_from_row(
            post_data, children=children, current_page=page_num
        )

    def get_post_by_path(
        self,
        path: str,
        page_num: Optional[int] = None,
    ) -> Optional[Post]:
        """Load the post stored at *path*; see get_post_by_id.

        Returns:
            The Post, or None if no post has that path.
        """
        post_id: Optional[int] = self.path_to_id(path)
        if post_id is None:
            return None
        return self.get_post_by_id(post_id=post_id, page_num=page_num)