diff --git a/config.py b/config.py index 1066fb0..bbb20e4 100644 --- a/config.py +++ b/config.py @@ -78,9 +78,10 @@ def init(): USER.datadir.mkdir(mode=0o755, exist_ok=True) if not USER.settingsfile.is_file(): - with open(USER.settingsfile, 'w+') as the_file: - the_file.write("[User Status]\nlastlogin = 0.0\n\n[User Settings]\nbrowser = lynx") - + with open(USER.settingsfile, "w+") as the_file: + the_file.write( + "[User Status]\nlastlogin = 0.0\n\n[User Settings]\nbrowser = lynx" + ) if not is_readable(USER.datadir.stat().st_mode): print( diff --git a/data.py b/data.py index f336c5b..ace7728 100644 --- a/data.py +++ b/data.py @@ -158,7 +158,13 @@ class LinkData: new_post_id = -1 if record.category: - new_post_id = max([record[0] for record in self.link_data if record[0]]) + 1 + if self.link_data: + new_post_id = ( + max([record[0] if record[0] else 0 for record in self.link_data]) + + 1 + ) + else: + new_post_id = 1 record = record._replace(ID_if_parent=new_post_id) self.link_data.insert(0, list(record)) self.generate_category_data() @@ -190,7 +196,11 @@ class LinkData: if keyword == "": raise ValueError("a search keyword must be specified") - query = (record for record in self.link_data if keyword.lower() in str(record).lower()) + query = ( + record + for record in self.link_data + if keyword.lower() in str(record).lower() + ) if query: search_results: set = set() diff --git a/linkulator.py b/linkulator.py index ce41ce1..77f03cc 100755 --- a/linkulator.py +++ b/linkulator.py @@ -24,18 +24,23 @@ categories: list = LinkData.categories def print_categories(): """Prints the list of categories with an indicator for new activity""" - print("\n{:>4s} New {:<25s}".format("ID#", "Category")) + header = "\n{:>4s} New {:<25s}".format("ID#", "Category") + out = "" for i, record in enumerate(categories): - print( - "{:4d} {} {} ({})".format( - i + 1, - "x" if record["last_updated"] >= config.USER.lastlogin else " ", - record["name"], - record["count"], - ) + out += "{:4d} {} {} ({})\n".format( + i + 1, + "x" if record["last_updated"] >= config.USER.lastlogin else " ", + record["name"], + record["count"], ) + if len(out) > 0: + print(header) + print(out) + else: + print("\n There are no posts yet - enter p to post a new link\n") + def print_category_details(view_cat): """produces category detail data, prints it to the console. returns dict @@ -90,7 +95,9 @@ def print_thread_details(post_id) -> tuple: raise ValueError("Sorry, no thread found with that ID.") # get replies data - replies = sorted([line for line in link_data if line[3] == parent_id], key=lambda x: x[2]) + replies = sorted( + [line for line in link_data if line[3] == parent_id], key=lambda x: x[2] + ) # post detail view print("\n\n{:<17}: {}".format(style_text("Title", "bold"), post_title)) @@ -173,6 +180,7 @@ def search(): # Catch a Post ID that is not in the thread list or is not a number print("\n{}\n".format(style_text("Invalid entry", "bold"))) + def view_link_in_browser(url): """Attempts to view the specified URL in the configured browser""" if which(config.USER.browser) is None: @@ -443,7 +451,12 @@ def main(): py_ver = str(sys.version_info[0]) + "." + str(sys.version_info[1]) py_ver = float(py_ver) if py_ver < low_py_ver: - raise Exception("Must be using Python " + str(low_py_ver) + " or higher. Instead you're using " + str(py_ver)) + raise Exception( + "Must be using Python " + + str(low_py_ver) + + " or higher. Instead you're using " + + str(py_ver) + ) signal.signal(signal.SIGINT, signal_handler) args = sys.argv[1:] diff --git a/tests/data_test.py b/tests/data_test.py index f9c5854..1cb586f 100644 --- a/tests/data_test.py +++ b/tests/data_test.py @@ -1,6 +1,7 @@ """unit tests for the data module""" import unittest import unittest.mock +from unittest.mock import mock_open from time import time import data @@ -214,5 +215,45 @@ class TestLinkDataSearch(unittest.TestCase): self.assertEqual(link_data.search("keyword"), test_results) +class TestAddLink(unittest.TestCase): + """Tests Add method of LinkData class""" + + @unittest.mock.patch.object(data.LinkData, "get") + def test_add_to_empty_db(self, mock_get): + """Tests adding to an empty db as per issue #76""" + link_data = data.LinkData() + mock_get.assert_called() + + link_data.link_data = [] + test_record = data.LinkDataRecord( + username="testuser", + timestamp="1000", + category="test_category", + link_URL="test_url", + link_title_or_comment="test_title", + ) + + with unittest.mock.patch("builtins.open", mock_open()) as mock_file: + + result = link_data.add(test_record) + + mock_file.assert_called() + self.assertEqual(1, result) + self.assertListEqual( + link_data.link_data, + [ + [ + 1, + "testuser", + "1000", + "", + "test_category", + "test_url", + "test_title", + ] + ], + ) + + if __name__ == "__main__": unittest.main()