Initial implementation of short-term caching.

This commit is contained in:
Solderpunk 2020-08-30 20:21:15 +02:00
parent 4d652e0fef
commit 0f328141b9
1 changed files with 145 additions and 76 deletions

221
av98.py
View File

@ -287,6 +287,9 @@ class GeminiClient(cmd.Cmd):
self._connect_to_tofu_db()
self.cache = {}
self.cache_timestamps = {}
def _connect_to_tofu_db(self):
db_path = os.path.join(self.config_dir, "tofu.db")
@ -303,6 +306,7 @@ class GeminiClient(cmd.Cmd):
sending the request over the network, parsing the response if
its a menu, storing the response in a temporary file, choosing
and calling a handler program, and updating the history."""
# Don't try to speak to servers running other protocols
if gi.scheme in ("http", "https"):
webbrowser.open_new_tab(gi.url)
@ -316,12 +320,58 @@ you'll be able to transparently follow links to Gopherspace!""")
elif gi.scheme not in ("gemini", "gopher"):
print("Sorry, no support for {} links.".format(gi.scheme))
return
# Obey permanent redirects
if gi.url in self.permanent_redirects:
new_gi = GeminiItem(self.permanent_redirects[gi.url], name=gi.name)
self._go_to_gi(new_gi)
return
# Use cache, or hit the network if resource is not cached
if self._is_cached(gi.url):
mime, body, tmpfile = self._get_cached(gi.url)
else:
try:
gi, mime, body, tmpfile = self._fetch_over_network(gi)
except Exception as err:
# Print an error message
if isinstance(err, socket.gaierror):
self.log["dns_failures"] += 1
print("ERROR: DNS error!")
elif isinstance(err, ConnectionRefusedError):
self.log["refused_connections"] += 1
print("ERROR: Connection refused!")
elif isinstance(err, ConnectionResetError):
self.log["reset_connections"] += 1
print("ERROR: Connection reset!")
elif isinstance(err, (TimeoutError, socket.timeout)):
self.log["timeouts"] += 1
print("""ERROR: Connection timed out!
Slow internet connection? Use 'set timeout' to be more patient.""")
else:
print("ERROR: " + str(err))
return
# Pass file to handler, unless we were asked not to
if handle:
if mime == "text/gemini":
self._handle_gemtext(body, gi)
else:
cmd_str = self._get_handler_cmd(mime)
try:
subprocess.call(shlex.split(cmd_str % tmpf.name))
except FileNotFoundError:
print("Handler program %s not found!" % shlex.split(cmd_str)[0])
print("You can use the ! command to specify another handler program or pipeline.")
# Update state
self.gi = gi
self.mime = mime
if update_hist:
self._update_history(gi)
def _fetch_over_network(self, gi):
# Be careful with client certificates!
# Are we crossing a domain boundary?
if self.active_cert_domains and gi.host not in self.active_cert_domains:
@ -353,50 +403,26 @@ you'll be able to transparently follow links to Gopherspace!""")
print("Remaining unidentified.")
self.client_certs.pop(gi.host)
# Do everything which touches the network in one block,
# so we only need to catch exceptions once
try:
# Is this a local file?
if not gi.host:
address, f = None, open(gi.path, "rb")
else:
address, f = self._send_request(gi)
# Is this a local file?
if not gi.host:
address, f = None, open(gi.path, "rb")
else:
address, f = self._send_request(gi)
# Spec dictates <META> should not exceed 1024 bytes,
# so maximum valid header length is 1027 bytes.
header = f.readline(1027)
header = header.decode("UTF-8")
if not header or header[-1] != '\n':
raise RuntimeError("Received invalid header from server!")
header = header.strip()
self._debug("Response header: %s." % header)
# Catch network errors which may happen on initial connection
except Exception as err:
# Print an error message
if isinstance(err, socket.gaierror):
self.log["dns_failures"] += 1
print("ERROR: DNS error!")
elif isinstance(err, ConnectionRefusedError):
self.log["refused_connections"] += 1
print("ERROR: Connection refused!")
elif isinstance(err, ConnectionResetError):
self.log["reset_connections"] += 1
print("ERROR: Connection reset!")
elif isinstance(err, (TimeoutError, socket.timeout)):
self.log["timeouts"] += 1
print("""ERROR: Connection timed out!
Slow internet connection? Use 'set timeout' to be more patient.""")
else:
print("ERROR: " + str(err))
return
# Spec dictates <META> should not exceed 1024 bytes,
# so maximum valid header length is 1027 bytes.
header = f.readline(1027)
header = header.decode("UTF-8")
if not header or header[-1] != '\n':
raise RuntimeError("Received invalid header from server!")
header = header.strip()
self._debug("Response header: %s." % header)
# Validate header
status, meta = header.split(maxsplit=1)
if len(meta) > 1024 or len(status) != 2 or not status.isnumeric():
print("ERROR: Received invalid header from server!")
f.close()
return
raise RuntimeError("Received invalid header from server!")
# Update redirect loop/maze escaping state
if not status.startswith("3"):
@ -410,20 +436,17 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
user_input = getpass.getpass("> ")
else:
user_input = input("> ")
self._go_to_gi(gi.query(user_input))
return
return self._fetch_over_network(gi.query(user_input))
# Redirects
elif status.startswith("3"):
new_gi = GeminiItem(gi.absolutise_url(meta))
if new_gi.url == gi.url:
print("Error: URL redirects to itself!")
return
raise RuntimeError("URL redirects to itself!")
elif new_gi.url in self.previous_redirectors:
print("Error: caught in redirect loop!")
return
raise RuntimeError("Caught in redirect loop!")
elif len(self.previous_redirectors) == _MAX_REDIRECTS:
print("Error: refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
return
raise RuntimeError("Refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
# Never follow cross-domain redirects without asking
elif new_gi.host != gi.host:
follow = input("Follow cross-domain redirect to %s? (y/n) " % new_gi.url)
@ -444,12 +467,12 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
if status == "31":
# Permanent redirect
self.permanent_redirects[gi.url] = new_gi.url
self._go_to_gi(new_gi)
return
return self._fetch_over_network(new_gi)
# Errors
elif status.startswith("4") or status.startswith("5"):
print("Error: %s" % meta)
return
raise RuntimeError(meta)
# Client cert
elif status.startswith("6"):
# Don't do client cert stuff in restricted mode, as in principle
@ -498,8 +521,7 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
# Invalid status
elif not status.startswith("2"):
print("ERROR: Server returned undefined status code %s!" % status)
return
raise RuntimeError("Server returned undefined status code %s!" % status)
# If we're here, this must be a success and there's a response body
assert status.startswith("2")
@ -512,16 +534,12 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
try:
codecs.lookup(mime_options["charset"])
except LookupError:
print("Header declared unknown encoding %s" % value)
return
raise RuntimeError("Header declared unknown encoding %s" % value)
# Read the response body over the network
body = f.read()
# Save the result in a temporary file
## Delete old file
if self.tmp_filename and os.path.exists(self.tmp_filename):
os.unlink(self.tmp_filename)
## Set file mode
if mime.startswith("text/"):
mode = "w"
@ -529,8 +547,7 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
try:
body = body.decode(encoding)
except UnicodeError:
print("Could not decode response body using %s encoding declared in header!" % encoding)
return
raise RuntimeError("Could not decode response body using %s encoding declared in header!" % encoding)
else:
mode = "wb"
encoding = None
@ -541,24 +558,11 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
self.tmp_filename = tmpf.name
self._debug("Wrote %d byte response to %s." % (size, self.tmp_filename))
# Pass file to handler, unless we were asked not to
if handle:
if mime == "text/gemini":
self._handle_gemtext(body, gi)
else:
cmd_str = self._get_handler_cmd(mime)
try:
subprocess.call(shlex.split(cmd_str % tmpf.name))
except FileNotFoundError:
print("Handler program %s not found!" % shlex.split(cmd_str)[0])
print("You can use the ! command to specify another handler program or pipeline.")
# Update state
self.gi = gi
self.mime = mime
# Maintain cache and log
self._add_to_cache(gi.url, mime, tmpf.name)
self._log_visit(gi, address, size)
if update_hist:
self._update_history(gi)
return gi, mime, body, tmpf
def _send_request(self, gi):
"""Send a selector to a given host and port.
@ -663,6 +667,69 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
return addresses
def _is_cached(self, url):
if url not in self.cache:
return False
now = time.time()
cached = self.cache_timestamps[url]
if now - cached > 180:
self._debug("Expiring old cached copy of resource.")
self._remove_from_cache(url)
return False
self._debug("Found cached copy of resource.")
return True
def _remove_from_cache(self, url):
self.cache_timestamps.pop(url)
mime, filename = self.cache.pop(url)
os.unlink(filename)
self._validate_cache()
def _add_to_cache(self, url, mime, filename):
self.cache_timestamps[url] = time.time()
self.cache[url] = (mime, filename)
if len(self.cache) > 10:
self._trim_cache()
self._validate_cache()
def _trim_cache(self):
# Order cache entries by age
lru = [(t, u) for (u, t) in self.cache_timestamps.items()]
lru.sort()
# Drop the oldest entry no matter what
_, url = lru[0]
self._debug("Dropping cached copy of {} from full cache.".format(url))
self._remove_from_cache(url)
# Drop other entries if they are older than the limit
now = time.time()
for cached, url in lru[1:]:
if now - cached > 180:
self._debug("Dropping cached copy of {} from full cache.".format(url))
self._remove_from_cache(url)
else:
break
self._validate_cache()
def _get_cached(self, url):
mime, filename = self.cache[url]
if mime.startswith("text/gemini"):
with open(filename, "r") as fp:
body = fp.read()
return mime, body, filename
else:
return mime, None, filename
def _empty_cache(self):
for mime, filename in self.cache.values():
if os.path.exists(filename):
os.unlink(filename)
def _validate_cache(self):
assert self.cache.keys() == self.cache_timestamps.keys()
for _, filename in self.cache.values():
assert os.path.isfile(filename)
def _validate_cert(self, address, host, cert):
"""
Validate a TLS certificate in TOFU mode.
@ -1483,10 +1550,12 @@ current gemini browsing session."""
self.db_conn.commit()
self.db_conn.close()
# Clean up after ourself
self._empty_cache()
if self.tmp_filename and os.path.exists(self.tmp_filename):
os.unlink(self.tmp_filename)
if self.idx_filename and os.path.exists(self.idx_filename):
os.unlink(self.idx_filename)
for cert in self.transient_certs_created:
for ext in (".crt", ".key"):
certfile = os.path.join(self.config_dir, "transient_certs", cert+ext)