333 lines
11 KiB
Python
333 lines
11 KiB
Python
import sqlite3
|
|
import time
|
|
import os
|
|
from typing import Optional, List, Tuple
|
|
from post import Post
|
|
|
|
|
|
class DB:
    """SQLite-backed store for posts and registered users.

    The ``settings`` mapping is read for:

    * ``settings["db"]["path"]`` - the database file passed to ``sqlite3.connect``;
    * ``settings["defaultPosts"]`` - posts seeded when the file is first created;
    * ``settings["postTypes"][<post_type>]`` - per-type ``scored``, ``sorting``,
      ``maxChildren`` and ``paginated`` options.
    """

    settings: dict
    name: str
    connection: sqlite3.Connection
    cursor: sqlite3.Cursor

    def __init__(self, settings: dict) -> None:
        """Open (or create) the database file named in ``settings["db"]["path"]``.

        When the file does not exist yet, the tables are created and the
        configured default posts are inserted.
        """
        self.settings = settings
        db_file_path: str = self.settings["db"]["path"]

        # Must be checked before connecting: sqlite3.connect creates the file.
        is_new_db: bool = not os.path.exists(db_file_path)

        self.connection = sqlite3.connect(db_file_path)
        self.cursor = self.connection.cursor()

        if is_new_db:
            self.init_post_table()
            self.create_default_posts()

    def init_post_table(self) -> None:
        """Create the ``posts`` and ``users`` tables if they do not exist."""
        # Column order matters: add_post inserts positionally and
        # _row_to_post reads rows by index.
        self.cursor.executescript(
            """CREATE TABLE IF NOT EXISTS posts
            (ip_address TEXT,
            post_type TEXT,
            post_id INTEGER,
            timestamp INTEGER,
            path TEXT,
            title TEXT,
            text TEXT,
            options TEXT,
            username TEXT,
            cert_hash TEXT,
            parent INTEGER,
            num_children INTEGER,
            score INTEGER);

            CREATE TABLE IF NOT EXISTS users (username TEXT, client_cert TEXT, cert_hash TEXT, status TEXT);"""
        )
        self.connection.commit()

    def latest_post_id(self) -> int:
        """Return the largest ``post_id`` in ``posts``, or 0 if the table is empty."""
        # MAX() always yields exactly one row; its value is NULL on an empty table.
        (latest_post,) = self.connection.execute(
            """SELECT MAX(post_id) FROM posts"""
        ).fetchone()
        return 0 if latest_post is None else latest_post

    def create_default_posts(self) -> None:
        """Insert every post listed in ``settings["defaultPosts"]``.

        Each key is used as the post title; its value supplies ``postType``,
        ``path``, ``text`` and ``parent``. Posts get consecutive ids after the
        current maximum and a ``time.time_ns()`` timestamp.
        """
        for title, post_settings in self.settings["defaultPosts"].items():
            self.add_post(
                ip_address="localhost",
                post_type=post_settings["postType"],
                post_id=self.latest_post_id() + 1,
                timestamp=time.time_ns(),
                path=post_settings["path"],
                text=post_settings["text"],
                title=title,
                parent=post_settings["parent"],
            )

    def hash_to_cert(self, cert_hash: str) -> Optional[str]:
        """Return the ``client_cert`` of the user with ``cert_hash``, or None."""
        row: Optional[Tuple[str]] = self.cursor.execute(
            """SELECT client_cert FROM users WHERE cert_hash=?""", (cert_hash,)
        ).fetchone()
        return None if row is None else row[0]

    def path_to_id(self, path: str) -> Optional[int]:
        """Return the ``post_id`` of the post stored at ``path``, or None."""
        row: Optional[Tuple[int]] = self.cursor.execute(
            """SELECT post_id FROM posts WHERE path=?""", (path,)
        ).fetchone()
        return None if row is None else row[0]

    def set_user_status(self, status: str, username: str) -> None:
        """Set the ``status`` column of every user row named ``username``."""
        self.cursor.execute(
            """UPDATE users SET status = ? WHERE username=?""", (status, username)
        )
        self.connection.commit()

    def append_to_post(self, post_id: int, text: str) -> None:
        """Append ``text`` to the body of post ``post_id``.

        A NULL body is treated as empty. If no such post exists, nothing is
        changed.
        """
        # COALESCE avoids `NULL || text` evaluating to NULL, so no prior read is needed.
        self.cursor.execute(
            """UPDATE posts SET text = COALESCE(text, '') || ? WHERE post_id=?""",
            (text, post_id),
        )
        self.connection.commit()

    def add_user(self, username: str, user_cert: str, cert_hash: str) -> bool:
        """Register a user with status ``"registered"``.

        Returns False without writing anything if ``username`` already exists,
        otherwise True.
        """
        already_taken: bool = (
            self.cursor.execute(
                """SELECT 1 FROM users WHERE username=?""", (username,)
            ).fetchone()
            is not None
        )
        if already_taken:
            return False

        self.cursor.execute(
            """INSERT INTO users VALUES (?,?,?,?)""",
            (username, user_cert, cert_hash, "registered"),
        )
        self.connection.commit()
        return True

    def get_user_cert(self, username: str) -> Optional[str]:
        """Return the ``client_cert`` stored for ``username``, or None."""
        row: Optional[Tuple[str]] = self.cursor.execute(
            """SELECT client_cert FROM users WHERE username=?""", (username,)
        ).fetchone()
        return None if row is None else row[0]

    def del_post_by_id(self, post_id: int, children: Optional[bool] = True) -> None:
        """Delete post ``post_id`` and, when ``children`` is truthy, its direct children.

        The deleted post's parent has its ``num_children`` decremented, mirroring
        the increment performed by ``add_post``.
        """
        # NOTE(review): only direct children are removed; deeper descendants are
        # left with a dangling ``parent``. The parent's ``score`` is not recomputed.
        row: Optional[Tuple[Optional[int]]] = self.cursor.execute(
            """SELECT parent FROM posts WHERE post_id=?""", (post_id,)
        ).fetchone()

        self.cursor.execute("""DELETE FROM posts WHERE post_id=?""", (post_id,))
        if children:
            self.cursor.execute("""DELETE FROM posts WHERE parent=?""", (post_id,))

        if row is not None and row[0] is not None:
            self.cursor.execute(
                """UPDATE posts SET num_children = num_children - 1 WHERE post_id=?""",
                (row[0],),
            )
        self.connection.commit()

    def update_post_score(self, post_id: int) -> None:
        """Recompute ``score`` for post ``post_id``.

        With direct children, the score is the integer (floor) mean of the
        post's own timestamp and its children's timestamps. Without children,
        it is the post's own timestamp.
        """
        score = self.cursor.execute(
            """SELECT timestamp FROM posts WHERE post_id=?""", (post_id,)
        ).fetchone()[0]

        # One query for all children instead of one lookup per child id.
        child_timestamps: List[int] = [
            child_row[0]
            for child_row in self.cursor.execute(
                """SELECT timestamp FROM posts WHERE parent=?""", (post_id,)
            ).fetchall()
        ]
        if child_timestamps:
            score = (score + sum(child_timestamps)) // (len(child_timestamps) + 1)

        self.cursor.execute(
            """UPDATE posts SET score = ? WHERE post_id=?""", (score, post_id)
        )
        self.connection.commit()

    def add_post(
        self,
        ip_address: str,
        post_type: str,
        post_id: int,
        timestamp: int,
        path: Optional[str] = None,
        title: Optional[str] = None,
        text: Optional[str] = None,
        options: Optional[str] = None,
        username: Optional[str] = None,
        cert_hash: Optional[str] = None,
        parent: Optional[int] = None,
    ) -> None:
        """Insert a post and update its parent's bookkeeping.

        ``path`` defaults to ``str(post_id)``. When ``parent`` is given, the path
        is prefixed with the parent's path (``"<parent_path>/<path>"``) and the
        parent's ``num_children`` is incremented. Scores are then recomputed for
        the new post and/or its parent when their post types are ``scored``.

        Raises TypeError if ``parent`` does not refer to an existing post.
        """
        parent_type: Optional[str] = None

        if path is None:
            path = str(post_id)

        if parent is not None:
            parent_path, parent_type = self.cursor.execute(
                """SELECT path, post_type FROM posts WHERE post_id=?""", (parent,)
            ).fetchone()
            path = "{}/{}".format(parent_path, path)

        # Positional values follow the posts column order; new posts start with
        # num_children = 0 and score = 0.
        self.cursor.execute(
            """INSERT INTO posts VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (
                ip_address,
                post_type,
                post_id,
                timestamp,
                path,
                title,
                text,
                options,
                username,
                cert_hash,
                parent,
                0,
                0,
            ),
        )

        if parent is not None:
            self.cursor.execute(
                """UPDATE posts SET num_children = num_children + 1 WHERE post_id=?""",
                (parent,),
            )

        self.connection.commit()

        if self.settings["postTypes"][post_type]["scored"]:
            self.update_post_score(post_id)

        if parent is not None and self.settings["postTypes"][parent_type]["scored"]:
            self.update_post_score(parent)

    def _row_to_post(self, row: tuple, **extra) -> "Post":
        """Build a Post from a full ``posts`` row; ``extra`` is forwarded to Post."""
        # Indices follow the column order in init_post_table; ip_address (0),
        # options (7) and score (12) are not passed to Post.
        return Post(
            settings=self.settings,
            post_type=row[1],
            post_id=row[2],
            timestamp=row[3],
            path=row[4],
            title=row[5],
            text=row[6],
            username=row[8],
            cert_hash=row[9],
            parent=row[10],
            num_children=row[11],
            **extra,
        )

    def get_post_by_id(
        self,
        post_id: int,
        page_num: Optional[int] = None,
    ) -> "Optional[Post]":
        """Load post ``post_id`` together with its direct children.

        Children are ordered by the post type's ``sorting`` column, descending:

        * ``maxChildren`` is None: all children are returned;
        * ``paginated``: one page of ``maxChildren`` children is returned and
          ``page_num`` is clamped to ``[1, last page]`` (an empty post has one page);
        * otherwise: at most ``maxChildren`` children are returned.

        Returns None when no such post exists.
        """
        post_data: Optional[tuple] = self.cursor.execute(
            """SELECT * FROM posts WHERE post_id=?""", (post_id,)
        ).fetchone()

        if post_data is None:
            return None

        type_settings: dict = self.settings["postTypes"][post_data[1]]
        # The ORDER BY column cannot be a bound parameter; it is interpolated
        # and therefore must come from trusted configuration only.
        sorting: str = type_settings["sorting"]
        max_children: Optional[int] = type_settings["maxChildren"]
        paginated: bool = type_settings["paginated"]

        children_query: str = (
            """SELECT * FROM posts WHERE parent=? ORDER BY %s DESC""" % sorting
        )

        if max_children is None:
            child_data = self.cursor.execute(children_query, (post_id,)).fetchall()
        elif paginated:
            num_children: int = self.cursor.execute(
                """SELECT COUNT (*) FROM posts WHERE parent=?""", (post_id,)
            ).fetchone()[0]
            # Integer ceiling division, floored at one page so that a post with
            # no children reports page 1 instead of page 0.
            max_page: int = max(1, -(-num_children // max_children))
            if page_num is None or page_num < 1:
                page_num = 1
            else:
                page_num = min(page_num, max_page)

            offset: int = (page_num - 1) * max_children
            child_data = self.cursor.execute(
                children_query + """ LIMIT ?, ?""",
                (post_id, offset, max_children),
            ).fetchall()
        else:
            child_data = self.cursor.execute(
                children_query, (post_id,)
            ).fetchmany(max_children)

        children: list = [self._row_to_post(child_row) for child_row in child_data]

        return self._row_to_post(post_data, children=children, current_page=page_num)

    def get_post_by_path(
        self,
        path: str,
        page_num: Optional[int] = None,
    ) -> "Optional[Post]":
        """Load the post stored at ``path`` (see get_post_by_id), or return None."""
        post_id: Optional[int] = self.path_to_id(path)

        if post_id is None:
            return None

        return self.get_post_by_id(post_id=post_id, page_num=page_num)