122 lines
3.0 KiB
Python
122 lines
3.0 KiB
Python
|
#!/usr/bin/env python3
|
||
|
import json
|
||
|
import time
|
||
|
import os
|
||
|
import cmd
|
||
|
import argparse
|
||
|
import threading
|
||
|
|
||
|
from typing import Optional, Union
|
||
|
from getpass import getpass
|
||
|
from auth import CA
|
||
|
from db import DB
|
||
|
from post import Post
|
||
|
|
||
|
|
||
|
class CommandLine(cmd.Cmd):
    """Interactive admin console for moderating posts and users.

    cmd.Cmd exposes every ``do_<name>`` method as the console command
    ``<name>``; the text typed after the command name arrives as ``arg``.
    Each command's docstring is what ``help <name>`` prints.
    """

    def __init__(self, ca: CA, db: DB) -> None:
        """Create the console.

        Args:
            ca: Certificate authority whose ``revoke_cert`` / ``unrevoke_cert``
                are used by the ban/unban commands.
            db: Database handle used to look up and modify posts and users.
        """
        super().__init__()
        self.intro: str = "hi."
        self.prompt: str = "motm> "
        self.ca = ca
        self.db = db

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _parse_post_id(arg: str) -> Optional[int]:
        """Return ``arg`` as an integer post id, or print a hint and return None."""
        # isdecimal() accepts exactly the digit characters int() understands;
        # isnumeric() would also let through e.g. "½" or "²", making int() raise.
        if not arg.isdecimal():
            print("Please supply an integer post id")
            return None
        return int(arg)

    def _lookup_user_cert(self, username: str) -> Optional[str]:
        """Return the stored certificate of ``username``, or print why not and return None."""
        if not username:
            print("Please supply a username")
            return None
        user_cert: Optional[str] = self.db.get_user_cert(username)
        if user_cert is None:
            print("User", username, "not found")
        return user_cert

    # ---------------------------------------------------------------- commands

    def do_get_post(self, arg: str) -> None:
        """get_post <id>: print the rendered post with the given id."""
        post_id = self._parse_post_id(arg)
        if post_id is None:
            return

        post = self.db.get_post_by_id(post_id)
        if post is None:
            print("Post with id", post_id, "not found")
            return

        print(post.render())

    def do_del_post(self, arg: str) -> None:
        """del_post <id>: delete the post with the given id."""
        post_id = self._parse_post_id(arg)
        if post_id is None:
            return

        # Check first so we never claim to have deleted a post that did not exist.
        if self.db.get_post_by_id(post_id) is None:
            print("Post with id", post_id, "not found")
            return

        self.db.del_post_by_id(post_id)
        print("Deleted post", post_id)

    def do_ban(self, arg: str) -> None:
        """ban <username>: revoke the user's certificate and set their status to banned."""
        username: str = arg
        user_cert = self._lookup_user_cert(username)
        if user_cert is None:
            return

        self.ca.revoke_cert(user_cert)
        self.db.set_user_status("banned", username)
        print("Banned user", username)

    def do_unban(self, arg: str) -> None:
        """unban <username>: un-revoke the user's certificate and set their status to registered."""
        username: str = arg
        user_cert = self._lookup_user_cert(username)
        if user_cert is None:
            return

        self.ca.unrevoke_cert(user_cert)
        self.db.set_user_status("registered", username)
        print("Unbanned user", username)

    def do_exit(self, arg: str) -> bool:
        """exit: leave the console."""
        return True  # a true return value stops cmdloop()

    def do_EOF(self, arg: str) -> bool:
        """Leave the console on end-of-input (Ctrl-D)."""
        # Without this, cmd.Cmd sends "EOF" to default() and, since stdin stays
        # at EOF, prints "*** Unknown syntax: EOF" forever.
        print()  # finish the prompt line left open by Ctrl-D
        return True

    # ------------------------------------------------------------ loop hooks

    def emptyline(self) -> bool:
        """Ignore blank input lines.

        cmd.Cmd's default re-executes the previous command, which would
        silently repeat destructive commands such as del_post or ban.
        """
        return False

    def postloop(self) -> None:
        """Say goodbye when cmdloop() ends normally (exit or EOF)."""
        self.end()

    def end(self) -> None:
        """Print the farewell message."""
        print("bye.")
|
||
|
|
||
|
|
||
|
def main():
    """Load settings, open (or create) the CA and database, and run the console.

    Settings are read from ``config/settings.json`` relative to the current
    working directory. If any of the CA certificate, key or CRL files named
    there is missing, the user is asked whether to create a new CA; any
    answer other than "y" (including end-of-input) returns without starting
    the console.
    """
    settings_path: str = "config/settings.json"
    # NOTE(review): never assigned anything but None, while getpass is imported
    # and unused — presumably a password prompt was intended; confirm.
    ca_password: Optional[str] = None

    # JSON is specified as UTF-8; don't depend on the platform's locale encoding.
    with open(settings_path, "r", encoding="utf-8") as settings_file:
        settings: dict = json.load(settings_file)

    auth = settings["auth"]
    if (
        os.path.exists(auth["caCert"])
        and os.path.exists(auth["caKey"])
        and os.path.exists(auth["caCRL"])
    ):
        ca = CA(
            settings["server"]["hostname"],
            auth["caCert"],
            auth["caKey"],
            auth["caCRL"],
            ca_password,
        )
    else:
        print("Some or all of the CA files are missing. Create a new CA?",
              "\n(this will overwrite any existing CA files) [y/n]")

        try:
            answer = input()
        except EOFError:
            # stdin closed (or Ctrl-D): treat as "no" instead of crashing.
            answer = ""
        # Tolerate stray whitespace such as " y" or "y ".
        if answer.strip().lower() != "y":
            return

        ca = CA.new_ca(
            cert_dir=auth["certDir"],
            hostname=settings["server"]["hostname"],
            password=ca_password,
        )

    db = DB(settings)
    command_line = CommandLine(ca, db)

    try:
        command_line.cmdloop()
    except KeyboardInterrupt:
        print()  # Ctrl-C leaves the cursor on the prompt line
        command_line.end()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Start the admin console only when executed as a script, not on import.
    main()
|