# motm/db.py
#
# 333 lines
# 11 KiB
# Python

import sqlite3
import time
import os
from typing import Optional, List, Tuple
from post import Post
class DB:
    """SQLite-backed storage for posts and users.

    Configuration keys read from ``settings``:

    * ``settings["db"]["path"]`` -- database passed to ``sqlite3.connect``.
    * ``settings["defaultPosts"]`` -- posts seeded into a newly created file.
    * ``settings["postTypes"][<post_type>]`` -- per-type ``scored``,
      ``sorting``, ``maxChildren`` and ``paginated`` options.

    All methods share one connection/cursor pair, and each write method
    commits before returning.

    ``posts`` column order (as used by ``SELECT *`` row indexing):
    0 ip_address, 1 post_type, 2 post_id, 3 timestamp, 4 path, 5 title,
    6 text, 7 options, 8 username, 9 cert_hash, 10 parent, 11 num_children,
    12 score.
    """

    settings: dict
    name: str
    connection: sqlite3.Connection
    cursor: sqlite3.Cursor

    def __init__(self, settings: dict) -> None:
        """Open the database at ``settings["db"]["path"]``.

        If the file does not exist yet, the tables are created and the
        configured default posts are inserted. An existing file is opened
        as-is, so the default posts are not inserted a second time.
        """
        self.settings = settings
        db_file_path: str = self.settings["db"]["path"]
        # Checked before connect(), because connecting creates the file.
        is_new_db: bool = not os.path.exists(db_file_path)
        self.connection = sqlite3.connect(db_file_path)
        self.cursor = self.connection.cursor()
        if is_new_db:
            self.init_post_table()
            self.create_default_posts()

    def init_post_table(self) -> None:
        """Create the ``posts`` and ``users`` tables if they do not exist."""
        self.cursor.executescript(
            """CREATE TABLE IF NOT EXISTS posts
            (ip_address TEXT,
            post_type TEXT,
            post_id INTEGER,
            timestamp INTEGER,
            path TEXT,
            title TEXT,
            text TEXT,
            options TEXT,
            username TEXT,
            cert_hash TEXT,
            parent INTEGER,
            num_children INTEGER,
            score INTEGER);
            CREATE TABLE IF NOT EXISTS users (username TEXT, client_cert TEXT, cert_hash TEXT, status TEXT);"""
        )
        self.connection.commit()

    def latest_post_id(self) -> int:
        """Return the largest ``post_id`` in ``posts``, or 0 if there are none."""
        latest_post: Optional[int] = self.connection.execute(
            """SELECT MAX(post_id) FROM posts"""
        ).fetchone()[0]
        return 0 if latest_post is None else latest_post

    def create_default_posts(self) -> None:
        """Insert every entry of ``settings["defaultPosts"]`` as a post.

        The entry's key becomes the post title. Ids are assigned sequentially
        after ``latest_post_id()``, the IP address is ``"localhost"`` and the
        timestamp is ``time.time_ns()``.
        """
        for post in self.settings["defaultPosts"]:
            post_settings: dict = self.settings["defaultPosts"][post]
            self.add_post(
                ip_address="localhost",
                post_type=post_settings["postType"],
                post_id=(self.latest_post_id() + 1),
                timestamp=time.time_ns(),
                path=post_settings["path"],
                text=post_settings["text"],
                title=post,
                parent=post_settings["parent"],
            )

    def hash_to_cert(self, cert_hash: str) -> Optional[str]:
        """Return the ``client_cert`` of the user with ``cert_hash``, or None."""
        client_cert: Optional[Tuple[str]] = self.cursor.execute(
            """SELECT client_cert FROM users WHERE cert_hash=?""", (cert_hash,)
        ).fetchone()
        return None if client_cert is None else client_cert[0]

    def path_to_id(self, path: str) -> Optional[int]:
        """Return the ``post_id`` of the post stored at ``path``, or None."""
        post_id: Optional[Tuple[int]] = self.cursor.execute(
            """SELECT post_id FROM posts WHERE path=?""", (path,)
        ).fetchone()
        return None if post_id is None else post_id[0]

    def set_user_status(self, status: str, username: str) -> None:
        """Set the ``status`` column of ``username``; no-op if the user is unknown."""
        self.cursor.execute(
            """UPDATE users SET status = ? WHERE username=?""", (status, username)
        )
        self.connection.commit()

    def append_to_post(self, post_id: int, text: str) -> None:
        """Append ``text`` to the post's text.

        A NULL text is treated as empty, because ``NULL || x`` is NULL in
        SQLite. If no post has ``post_id``, nothing is updated.
        """
        self.cursor.execute(
            """UPDATE posts SET text = COALESCE(text, '') || ? WHERE post_id=?""",
            (text, post_id),
        )
        self.connection.commit()

    def add_user(self, username: str, user_cert: str, cert_hash: str) -> bool:
        """Register a new user with status ``"registered"``.

        Returns:
            True if the user was inserted, or False if ``username`` already
            exists (in which case nothing is written).
        """
        user_exists: bool = (
            self.cursor.execute(
                """SELECT 1 FROM users WHERE username=?""", (username,)
            ).fetchone()
            is not None
        )
        if user_exists:
            return False
        self.cursor.execute(
            """INSERT INTO users VALUES (?,?,?,?)""",
            (username, user_cert, cert_hash, "registered"),
        )
        self.connection.commit()
        return True

    def get_user_cert(self, username: str) -> Optional[str]:
        """Return the ``client_cert`` stored for ``username``, or None."""
        user_cert: Optional[Tuple[str]] = self.cursor.execute(
            """SELECT client_cert FROM users WHERE username=?""", (username,)
        ).fetchone()
        return None if user_cert is None else user_cert[0]

    def del_post_by_id(self, post_id: int, children: Optional[bool] = True):
        """Delete a post, optionally with its whole subtree.

        Args:
            post_id: id of the post to delete.
            children: when truthy, all descendants (children, grandchildren,
                ...) are deleted too, so no rows are left pointing at a
                deleted parent. When falsy, only the post itself is removed.

        The deleted post's parent has its stored ``num_children`` decremented,
        mirroring the increment done in ``add_post``.
        NOTE(review): the parent's ``score`` is not recomputed here.
        """
        parent_row: Optional[Tuple[Optional[int]]] = self.cursor.execute(
            """SELECT parent FROM posts WHERE post_id=?""", (post_id,)
        ).fetchone()
        if children:
            # UNION (not UNION ALL) also guards against accidental cycles.
            self.cursor.execute(
                """WITH RECURSIVE subtree(id) AS (
                       SELECT ?
                       UNION
                       SELECT posts.post_id FROM posts
                       JOIN subtree ON posts.parent = subtree.id
                   )
                   DELETE FROM posts WHERE post_id IN (SELECT id FROM subtree)""",
                (post_id,),
            )
        else:
            self.cursor.execute("""DELETE FROM posts WHERE post_id=?""", (post_id,))
        if parent_row is not None and parent_row[0] is not None:
            self.cursor.execute(
                """UPDATE posts SET num_children = num_children - 1 WHERE post_id=?""",
                (parent_row[0],),
            )
        self.connection.commit()

    def update_post_score(self, post_id: int) -> None:
        """Store the post's score.

        The score is the floor of the mean of the post's timestamp and the
        timestamps of its direct children.

        Raises:
            TypeError: if no post has ``post_id``.
        """
        own_timestamp: int = self.cursor.execute(
            """SELECT timestamp FROM posts WHERE post_id=?""", (post_id,)
        ).fetchone()[0]
        child_timestamps: List[int] = [
            row[0]
            for row in self.cursor.execute(
                """SELECT timestamp FROM posts WHERE parent=?""", (post_id,)
            ).fetchall()
        ]
        # Summed in Python, not with SQL SUM(): a few nanosecond timestamps
        # already exceed SQLite's 64-bit integer range and SUM() would raise.
        score: int = (own_timestamp + sum(child_timestamps)) // (
            len(child_timestamps) + 1
        )
        self.cursor.execute(
            """UPDATE posts SET score = ? WHERE post_id=?""", (score, post_id)
        )
        self.connection.commit()

    def add_post(
        self,
        ip_address: str,
        post_type: str,
        post_id: int,
        timestamp: int,
        path: Optional[str] = None,
        title: Optional[str] = None,
        text: Optional[str] = None,
        options: Optional[str] = None,
        username: Optional[str] = None,
        cert_hash: Optional[str] = None,
        parent: Optional[int] = None,
    ) -> None:
        """Insert a post and update the related bookkeeping.

        ``path`` defaults to ``str(post_id)``. When ``parent`` is given:
        - the stored path becomes ``"<parent path>/<path>"``;
        - the parent's ``num_children`` is incremented;
        - the parent's score is refreshed if its type is ``scored``.

        The new post's own score is refreshed if ``post_type`` is ``scored``.

        Raises:
            TypeError: if ``parent`` is given but no such post exists.
            KeyError: if a post type is missing from ``settings["postTypes"]``.
        """
        if path is None:
            path = str(post_id)
        parent_type: Optional[str] = None
        if parent is not None:
            # Unpacking fails with TypeError when the parent row is missing.
            parent_path, parent_type = self.cursor.execute(
                """SELECT path, post_type FROM posts WHERE post_id=?""", (parent,)
            ).fetchone()
            path = "{}/{}".format(parent_path, path)
        post_tuple: tuple = (
            ip_address,
            post_type,
            post_id,
            timestamp,
            path,
            title,
            text,
            options,
            username,
            cert_hash,
            parent,
            0,  # num_children
            0,  # score
        )
        self.cursor.execute(
            """INSERT INTO posts VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)""", post_tuple
        )
        if parent is not None:
            self.cursor.execute(
                """UPDATE posts SET num_children = num_children + 1 WHERE post_id=?""",
                (parent,),
            )
        self.connection.commit()
        post_types: dict = self.settings["postTypes"]
        if post_types[post_type]["scored"]:
            self.update_post_score(post_id)
        if parent is not None and post_types[parent_type]["scored"]:
            self.update_post_score(parent)

    def _row_to_post(self, row: tuple, **extra) -> "Post":
        """Build a ``Post`` from a full ``posts`` row in ``SELECT *`` order.

        ``ip_address``, ``options`` and ``score`` are not passed to ``Post``.
        ``extra`` is forwarded as additional keyword arguments.
        """
        return Post(
            settings=self.settings,
            post_type=row[1],
            post_id=row[2],
            timestamp=row[3],
            path=row[4],
            title=row[5],
            text=row[6],
            username=row[8],
            cert_hash=row[9],
            parent=row[10],
            num_children=row[11],
            **extra,
        )

    def get_post_by_id(
        self,
        post_id: int,
        page_num: Optional[int] = None,
    ) -> "Optional[Post]":
        """Load a post together with (a page of) its direct children.

        Children are ordered by the type's ``sorting`` column, descending.
        - ``maxChildren`` None: all children are returned.
        - ``paginated``: ``page_num`` is clamped to ``[1, last page]`` and
          that page of ``maxChildren`` children is returned. A post without
          children has one (empty) page.
        - Otherwise: the first ``maxChildren`` children are returned.

        Returns:
            The ``Post``, or None if no post has ``post_id``.
        """
        post_data: Optional[tuple] = self.cursor.execute(
            """SELECT * FROM posts WHERE post_id=?""", (post_id,)
        ).fetchone()
        if post_data is None:
            return None
        type_settings: dict = self.settings["postTypes"][post_data[1]]
        # Column names cannot be bound as SQL parameters, so ``sorting`` is
        # interpolated. It must come from trusted configuration only.
        sorting: str = type_settings["sorting"]
        max_children: Optional[int] = type_settings["maxChildren"]
        paginated: bool = type_settings["paginated"]
        children_query: str = (
            """SELECT * FROM posts WHERE parent=? ORDER BY %s DESC""" % sorting
        )
        child_data: List[tuple]
        if max_children is None:
            child_data = self.cursor.execute(children_query, (post_id,)).fetchall()
        elif paginated:
            num_children: int = self.cursor.execute(
                """SELECT COUNT (*) FROM posts WHERE parent=?""", (post_id,)
            ).fetchone()[0]
            # Integer ceiling division, with at least one page so that an
            # empty post reports page 1 instead of page 0.
            max_page: int = max(1, -(-num_children // max_children))
            page_num = 1 if page_num is None or page_num < 1 else page_num
            page_num = min(page_num, max_page)
            offset: int = (page_num - 1) * max_children
            child_data = self.cursor.execute(
                children_query + " LIMIT ?, ?", (post_id, offset, max_children)
            ).fetchall()
        else:
            child_data = self.cursor.execute(
                children_query, (post_id,)
            ).fetchmany(max_children)
        children: list = [self._row_to_post(child_row) for child_row in child_data]
        return self._row_to_post(post_data, children=children, current_page=page_num)

    def get_post_by_path(
        self,
        path: str,
        page_num: Optional[int] = None,
    ) -> "Optional[Post]":
        """Resolve ``path`` to an id and delegate to ``get_post_by_id``.

        Returns None if no post is stored at ``path``.
        """
        post_id: Optional[int] = self.path_to_id(path)
        if post_id is None:
            return None
        return self.get_post_by_id(
            post_id=post_id,
            page_num=page_num,
        )