diff --git a/av98.py b/av98.py
index 7671ee8..c877f95 100755
--- a/av98.py
+++ b/av98.py
@@ -287,6 +287,9 @@ class GeminiClient(cmd.Cmd):
self._connect_to_tofu_db()
+ self.cache = {}
+ self.cache_timestamps = {}
+
def _connect_to_tofu_db(self):
db_path = os.path.join(self.config_dir, "tofu.db")
@@ -303,6 +306,7 @@ class GeminiClient(cmd.Cmd):
sending the request over the network, parsing the response if
its a menu, storing the response in a temporary file, choosing
and calling a handler program, and updating the history."""
+
# Don't try to speak to servers running other protocols
if gi.scheme in ("http", "https"):
webbrowser.open_new_tab(gi.url)
@@ -316,12 +320,58 @@ you'll be able to transparently follow links to Gopherspace!""")
elif gi.scheme not in ("gemini", "gopher"):
print("Sorry, no support for {} links.".format(gi.scheme))
return
+
# Obey permanent redirects
if gi.url in self.permanent_redirects:
new_gi = GeminiItem(self.permanent_redirects[gi.url], name=gi.name)
self._go_to_gi(new_gi)
return
+ # Use cache, or hit the network if resource is not cached
+ if self._is_cached(gi.url):
+ mime, body, tmpfile = self._get_cached(gi.url)
+ else:
+ try:
+ gi, mime, body, tmpfile = self._fetch_over_network(gi)
+ except Exception as err:
+ # Print an error message
+ if isinstance(err, socket.gaierror):
+ self.log["dns_failures"] += 1
+ print("ERROR: DNS error!")
+ elif isinstance(err, ConnectionRefusedError):
+ self.log["refused_connections"] += 1
+ print("ERROR: Connection refused!")
+ elif isinstance(err, ConnectionResetError):
+ self.log["reset_connections"] += 1
+ print("ERROR: Connection reset!")
+ elif isinstance(err, (TimeoutError, socket.timeout)):
+ self.log["timeouts"] += 1
+ print("""ERROR: Connection timed out!
+ Slow internet connection? Use 'set timeout' to be more patient.""")
+ else:
+ print("ERROR: " + str(err))
+ return
+
+ # Pass file to handler, unless we were asked not to
+ if handle:
+ if mime == "text/gemini":
+ self._handle_gemtext(body, gi)
+ else:
+ cmd_str = self._get_handler_cmd(mime)
+ try:
+ subprocess.call(shlex.split(cmd_str % tmpf.name))
+ except FileNotFoundError:
+ print("Handler program %s not found!" % shlex.split(cmd_str)[0])
+ print("You can use the ! command to specify another handler program or pipeline.")
+
+ # Update state
+ self.gi = gi
+ self.mime = mime
+ if update_hist:
+ self._update_history(gi)
+
+ def _fetch_over_network(self, gi):
+
# Be careful with client certificates!
# Are we crossing a domain boundary?
if self.active_cert_domains and gi.host not in self.active_cert_domains:
@@ -353,50 +403,26 @@ you'll be able to transparently follow links to Gopherspace!""")
print("Remaining unidentified.")
self.client_certs.pop(gi.host)
- # Do everything which touches the network in one block,
- # so we only need to catch exceptions once
- try:
- # Is this a local file?
- if not gi.host:
- address, f = None, open(gi.path, "rb")
- else:
- address, f = self._send_request(gi)
+ # Is this a local file?
+ if not gi.host:
+ address, f = None, open(gi.path, "rb")
+ else:
+ address, f = self._send_request(gi)
- # Spec dictates should not exceed 1024 bytes,
- # so maximum valid header length is 1027 bytes.
- header = f.readline(1027)
- header = header.decode("UTF-8")
- if not header or header[-1] != '\n':
- raise RuntimeError("Received invalid header from server!")
- header = header.strip()
- self._debug("Response header: %s." % header)
-
- # Catch network errors which may happen on initial connection
- except Exception as err:
- # Print an error message
- if isinstance(err, socket.gaierror):
- self.log["dns_failures"] += 1
- print("ERROR: DNS error!")
- elif isinstance(err, ConnectionRefusedError):
- self.log["refused_connections"] += 1
- print("ERROR: Connection refused!")
- elif isinstance(err, ConnectionResetError):
- self.log["reset_connections"] += 1
- print("ERROR: Connection reset!")
- elif isinstance(err, (TimeoutError, socket.timeout)):
- self.log["timeouts"] += 1
- print("""ERROR: Connection timed out!
-Slow internet connection? Use 'set timeout' to be more patient.""")
- else:
- print("ERROR: " + str(err))
- return
+ # Spec dictates should not exceed 1024 bytes,
+ # so maximum valid header length is 1027 bytes.
+ header = f.readline(1027)
+ header = header.decode("UTF-8")
+ if not header or header[-1] != '\n':
+ raise RuntimeError("Received invalid header from server!")
+ header = header.strip()
+ self._debug("Response header: %s." % header)
# Validate header
status, meta = header.split(maxsplit=1)
if len(meta) > 1024 or len(status) != 2 or not status.isnumeric():
- print("ERROR: Received invalid header from server!")
f.close()
- return
+ raise RuntimeError("Received invalid header from server!")
# Update redirect loop/maze escaping state
if not status.startswith("3"):
@@ -410,20 +436,17 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
user_input = getpass.getpass("> ")
else:
user_input = input("> ")
- self._go_to_gi(gi.query(user_input))
- return
+ return self._fetch_over_network(gi.query(user_input))
+
# Redirects
elif status.startswith("3"):
new_gi = GeminiItem(gi.absolutise_url(meta))
if new_gi.url == gi.url:
- print("Error: URL redirects to itself!")
- return
+ raise RuntimeError("URL redirects to itself!")
elif new_gi.url in self.previous_redirectors:
- print("Error: caught in redirect loop!")
- return
+ raise RuntimeError("Caught in redirect loop!")
elif len(self.previous_redirectors) == _MAX_REDIRECTS:
- print("Error: refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
- return
+ raise RuntimeError("Refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
# Never follow cross-domain redirects without asking
elif new_gi.host != gi.host:
follow = input("Follow cross-domain redirect to %s? (y/n) " % new_gi.url)
@@ -444,12 +467,12 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
if status == "31":
# Permanent redirect
self.permanent_redirects[gi.url] = new_gi.url
- self._go_to_gi(new_gi)
- return
+ return self._fetch_over_network(new_gi)
+
# Errors
elif status.startswith("4") or status.startswith("5"):
- print("Error: %s" % meta)
- return
+ raise RuntimeError(meta)
+
# Client cert
elif status.startswith("6"):
# Don't do client cert stuff in restricted mode, as in principle
@@ -498,8 +521,7 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
# Invalid status
elif not status.startswith("2"):
- print("ERROR: Server returned undefined status code %s!" % status)
- return
+ raise RuntimeError("Server returned undefined status code %s!" % status)
# If we're here, this must be a success and there's a response body
assert status.startswith("2")
@@ -512,16 +534,12 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
try:
codecs.lookup(mime_options["charset"])
except LookupError:
- print("Header declared unknown encoding %s" % value)
- return
+ raise RuntimeError("Header declared unknown encoding %s" % value)
# Read the response body over the network
body = f.read()
# Save the result in a temporary file
- ## Delete old file
- if self.tmp_filename and os.path.exists(self.tmp_filename):
- os.unlink(self.tmp_filename)
## Set file mode
if mime.startswith("text/"):
mode = "w"
@@ -529,8 +547,7 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
try:
body = body.decode(encoding)
except UnicodeError:
- print("Could not decode response body using %s encoding declared in header!" % encoding)
- return
+ raise RuntimeError("Could not decode response body using %s encoding declared in header!" % encoding)
else:
mode = "wb"
encoding = None
@@ -541,24 +558,11 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
self.tmp_filename = tmpf.name
self._debug("Wrote %d byte response to %s." % (size, self.tmp_filename))
- # Pass file to handler, unless we were asked not to
- if handle:
- if mime == "text/gemini":
- self._handle_gemtext(body, gi)
- else:
- cmd_str = self._get_handler_cmd(mime)
- try:
- subprocess.call(shlex.split(cmd_str % tmpf.name))
- except FileNotFoundError:
- print("Handler program %s not found!" % shlex.split(cmd_str)[0])
- print("You can use the ! command to specify another handler program or pipeline.")
-
- # Update state
- self.gi = gi
- self.mime = mime
+ # Maintain cache and log
+ self._add_to_cache(gi.url, mime, tmpf.name)
self._log_visit(gi, address, size)
- if update_hist:
- self._update_history(gi)
+
+ return gi, mime, body, tmpf
def _send_request(self, gi):
"""Send a selector to a given host and port.
@@ -663,6 +667,69 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
return addresses
+ def _is_cached(self, url):
+ if url not in self.cache:
+ return False
+ now = time.time()
+ cached = self.cache_timestamps[url]
+ if now - cached > 180:
+ self._debug("Expiring old cached copy of resource.")
+ self._remove_from_cache(url)
+ return False
+ self._debug("Found cached copy of resource.")
+ return True
+
+ def _remove_from_cache(self, url):
+ self.cache_timestamps.pop(url)
+ mime, filename = self.cache.pop(url)
+ os.unlink(filename)
+ self._validate_cache()
+
+ def _add_to_cache(self, url, mime, filename):
+
+ self.cache_timestamps[url] = time.time()
+ self.cache[url] = (mime, filename)
+ if len(self.cache) > 10:
+ self._trim_cache()
+ self._validate_cache()
+
+ def _trim_cache(self):
+ # Order cache entries by age
+ lru = [(t, u) for (u, t) in self.cache_timestamps.items()]
+ lru.sort()
+ # Drop the oldest entry no matter what
+ _, url = lru[0]
+ self._debug("Dropping cached copy of {} from full cache.".format(url))
+ self._remove_from_cache(url)
+ # Drop other entries if they are older than the limit
+ now = time.time()
+ for cached, url in lru[1:]:
+ if now - cached > 180:
+ self._debug("Dropping cached copy of {} from full cache.".format(url))
+ self._remove_from_cache(url)
+ else:
+ break
+ self._validate_cache()
+
+ def _get_cached(self, url):
+ mime, filename = self.cache[url]
+ if mime.startswith("text/gemini"):
+ with open(filename, "r") as fp:
+ body = fp.read()
+ return mime, body, filename
+ else:
+ return mime, None, filename
+
+ def _empty_cache(self):
+ for mime, filename in self.cache.values():
+ if os.path.exists(filename):
+ os.unlink(filename)
+
+ def _validate_cache(self):
+ assert self.cache.keys() == self.cache_timestamps.keys()
+ for _, filename in self.cache.values():
+ assert os.path.isfile(filename)
+
def _validate_cert(self, address, host, cert):
"""
Validate a TLS certificate in TOFU mode.
@@ -1483,10 +1550,12 @@ current gemini browsing session."""
self.db_conn.commit()
self.db_conn.close()
# Clean up after ourself
+ self._empty_cache()
if self.tmp_filename and os.path.exists(self.tmp_filename):
os.unlink(self.tmp_filename)
if self.idx_filename and os.path.exists(self.idx_filename):
os.unlink(self.idx_filename)
+
for cert in self.transient_certs_created:
for ext in (".crt", ".key"):
certfile = os.path.join(self.config_dir, "transient_certs", cert+ext)