diff --git a/av98.py b/av98.py index 7671ee8..c877f95 100755 --- a/av98.py +++ b/av98.py @@ -287,6 +287,9 @@ class GeminiClient(cmd.Cmd): self._connect_to_tofu_db() + self.cache = {} + self.cache_timestamps = {} + def _connect_to_tofu_db(self): db_path = os.path.join(self.config_dir, "tofu.db") @@ -303,6 +306,7 @@ class GeminiClient(cmd.Cmd): sending the request over the network, parsing the response if its a menu, storing the response in a temporary file, choosing and calling a handler program, and updating the history.""" + # Don't try to speak to servers running other protocols if gi.scheme in ("http", "https"): webbrowser.open_new_tab(gi.url) @@ -316,12 +320,58 @@ you'll be able to transparently follow links to Gopherspace!""") elif gi.scheme not in ("gemini", "gopher"): print("Sorry, no support for {} links.".format(gi.scheme)) return + # Obey permanent redirects if gi.url in self.permanent_redirects: new_gi = GeminiItem(self.permanent_redirects[gi.url], name=gi.name) self._go_to_gi(new_gi) return + # Use cache, or hit the network if resource is not cached + if self._is_cached(gi.url): + mime, body, tmpfile = self._get_cached(gi.url) + else: + try: + gi, mime, body, tmpfile = self._fetch_over_network(gi) + except Exception as err: + # Print an error message + if isinstance(err, socket.gaierror): + self.log["dns_failures"] += 1 + print("ERROR: DNS error!") + elif isinstance(err, ConnectionRefusedError): + self.log["refused_connections"] += 1 + print("ERROR: Connection refused!") + elif isinstance(err, ConnectionResetError): + self.log["reset_connections"] += 1 + print("ERROR: Connection reset!") + elif isinstance(err, (TimeoutError, socket.timeout)): + self.log["timeouts"] += 1 + print("""ERROR: Connection timed out! + Slow internet connection? Use 'set timeout' to be more patient.""") + else: + print("ERROR: " + str(err)) + return + + # Pass file to handler, unless we were asked not to + if handle: + if mime == "text/gemini": + self._handle_gemtext(body, gi) + else: + cmd_str = self._get_handler_cmd(mime) + try: + subprocess.call(shlex.split(cmd_str % tmpf.name)) + except FileNotFoundError: + print("Handler program %s not found!" % shlex.split(cmd_str)[0]) + print("You can use the ! command to specify another handler program or pipeline.") + + # Update state + self.gi = gi + self.mime = mime + if update_hist: + self._update_history(gi) + + def _fetch_over_network(self, gi): + # Be careful with client certificates! # Are we crossing a domain boundary? if self.active_cert_domains and gi.host not in self.active_cert_domains: @@ -353,50 +403,26 @@ you'll be able to transparently follow links to Gopherspace!""") print("Remaining unidentified.") self.client_certs.pop(gi.host) - # Do everything which touches the network in one block, - # so we only need to catch exceptions once - try: - # Is this a local file? - if not gi.host: - address, f = None, open(gi.path, "rb") - else: - address, f = self._send_request(gi) + # Is this a local file? + if not gi.host: + address, f = None, open(gi.path, "rb") + else: + address, f = self._send_request(gi) - # Spec dictates should not exceed 1024 bytes, - # so maximum valid header length is 1027 bytes. - header = f.readline(1027) - header = header.decode("UTF-8") - if not header or header[-1] != '\n': - raise RuntimeError("Received invalid header from server!") - header = header.strip() - self._debug("Response header: %s." % header) - - # Catch network errors which may happen on initial connection - except Exception as err: - # Print an error message - if isinstance(err, socket.gaierror): - self.log["dns_failures"] += 1 - print("ERROR: DNS error!") - elif isinstance(err, ConnectionRefusedError): - self.log["refused_connections"] += 1 - print("ERROR: Connection refused!") - elif isinstance(err, ConnectionResetError): - self.log["reset_connections"] += 1 - print("ERROR: Connection reset!") - elif isinstance(err, (TimeoutError, socket.timeout)): - self.log["timeouts"] += 1 - print("""ERROR: Connection timed out! -Slow internet connection? Use 'set timeout' to be more patient.""") - else: - print("ERROR: " + str(err)) - return + # Spec dictates should not exceed 1024 bytes, + # so maximum valid header length is 1027 bytes. + header = f.readline(1027) + header = header.decode("UTF-8") + if not header or header[-1] != '\n': + raise RuntimeError("Received invalid header from server!") + header = header.strip() + self._debug("Response header: %s." % header) # Validate header status, meta = header.split(maxsplit=1) if len(meta) > 1024 or len(status) != 2 or not status.isnumeric(): - print("ERROR: Received invalid header from server!") f.close() - return + raise RuntimeError("Received invalid header from server!") # Update redirect loop/maze escaping state if not status.startswith("3"): @@ -410,20 +436,17 @@ Slow internet connection? Use 'set timeout' to be more patient.""") user_input = getpass.getpass("> ") else: user_input = input("> ") - self._go_to_gi(gi.query(user_input)) - return + return self._fetch_over_network(gi.query(user_input)) + # Redirects elif status.startswith("3"): new_gi = GeminiItem(gi.absolutise_url(meta)) if new_gi.url == gi.url: - print("Error: URL redirects to itself!") - return + raise RuntimeError("URL redirects to itself!") elif new_gi.url in self.previous_redirectors: - print("Error: caught in redirect loop!") - return + raise RuntimeError("Caught in redirect loop!") elif len(self.previous_redirectors) == _MAX_REDIRECTS: - print("Error: refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS) - return + raise RuntimeError("Refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS) # Never follow cross-domain redirects without asking elif new_gi.host != gi.host: follow = input("Follow cross-domain redirect to %s? (y/n) " % new_gi.url) @@ -444,12 +467,12 @@ Slow internet connection? Use 'set timeout' to be more patient.""") if status == "31": # Permanent redirect self.permanent_redirects[gi.url] = new_gi.url - self._go_to_gi(new_gi) - return + return self._fetch_over_network(new_gi) + # Errors elif status.startswith("4") or status.startswith("5"): - print("Error: %s" % meta) - return + raise RuntimeError(meta) + # Client cert elif status.startswith("6"): # Don't do client cert stuff in restricted mode, as in principle @@ -498,8 +521,7 @@ Slow internet connection? Use 'set timeout' to be more patient.""") # Invalid status elif not status.startswith("2"): - print("ERROR: Server returned undefined status code %s!" % status) - return + raise RuntimeError("Server returned undefined status code %s!" % status) # If we're here, this must be a success and there's a response body assert status.startswith("2") @@ -512,16 +534,12 @@ Slow internet connection? Use 'set timeout' to be more patient.""") try: codecs.lookup(mime_options["charset"]) except LookupError: - print("Header declared unknown encoding %s" % value) - return + raise RuntimeError("Header declared unknown encoding %s" % value) # Read the response body over the network body = f.read() # Save the result in a temporary file - ## Delete old file - if self.tmp_filename and os.path.exists(self.tmp_filename): - os.unlink(self.tmp_filename) ## Set file mode if mime.startswith("text/"): mode = "w" @@ -529,8 +547,7 @@ Slow internet connection? Use 'set timeout' to be more patient.""") try: body = body.decode(encoding) except UnicodeError: - print("Could not decode response body using %s encoding declared in header!" % encoding) - return + raise RuntimeError("Could not decode response body using %s encoding declared in header!" % encoding) else: mode = "wb" encoding = None @@ -541,24 +558,11 @@ Slow internet connection? Use 'set timeout' to be more patient.""") self.tmp_filename = tmpf.name self._debug("Wrote %d byte response to %s." % (size, self.tmp_filename)) - # Pass file to handler, unless we were asked not to - if handle: - if mime == "text/gemini": - self._handle_gemtext(body, gi) - else: - cmd_str = self._get_handler_cmd(mime) - try: - subprocess.call(shlex.split(cmd_str % tmpf.name)) - except FileNotFoundError: - print("Handler program %s not found!" % shlex.split(cmd_str)[0]) - print("You can use the ! command to specify another handler program or pipeline.") - - # Update state - self.gi = gi - self.mime = mime + # Maintain cache and log + self._add_to_cache(gi.url, mime, tmpf.name) self._log_visit(gi, address, size) - if update_hist: - self._update_history(gi) + + return gi, mime, body, tmpf def _send_request(self, gi): """Send a selector to a given host and port. @@ -663,6 +667,69 @@ Slow internet connection? Use 'set timeout' to be more patient.""") return addresses + def _is_cached(self, url): + if url not in self.cache: + return False + now = time.time() + cached = self.cache_timestamps[url] + if now - cached > 180: + self._debug("Expiring old cached copy of resource.") + self._remove_from_cache(url) + return False + self._debug("Found cached copy of resource.") + return True + + def _remove_from_cache(self, url): + self.cache_timestamps.pop(url) + mime, filename = self.cache.pop(url) + os.unlink(filename) + self._validate_cache() + + def _add_to_cache(self, url, mime, filename): + + self.cache_timestamps[url] = time.time() + self.cache[url] = (mime, filename) + if len(self.cache) > 10: + self._trim_cache() + self._validate_cache() + + def _trim_cache(self): + # Order cache entries by age + lru = [(t, u) for (u, t) in self.cache_timestamps.items()] + lru.sort() + # Drop the oldest entry no matter what + _, url = lru[0] + self._debug("Dropping cached copy of {} from full cache.".format(url)) + self._remove_from_cache(url) + # Drop other entries if they are older than the limit + now = time.time() + for cached, url in lru[1:]: + if now - cached > 180: + self._debug("Dropping cached copy of {} from full cache.".format(url)) + self._remove_from_cache(url) + else: + break + self._validate_cache() + + def _get_cached(self, url): + mime, filename = self.cache[url] + if mime.startswith("text/gemini"): + with open(filename, "r") as fp: + body = fp.read() + return mime, body, filename + else: + return mime, None, filename + + def _empty_cache(self): + for mime, filename in self.cache.values(): + if os.path.exists(filename): + os.unlink(filename) + + def _validate_cache(self): + assert self.cache.keys() == self.cache_timestamps.keys() + for _, filename in self.cache.values(): + assert os.path.isfile(filename) + def _validate_cert(self, address, host, cert): """ Validate a TLS certificate in TOFU mode. @@ -1483,10 +1550,12 @@ current gemini browsing session.""" self.db_conn.commit() self.db_conn.close() # Clean up after ourself + self._empty_cache() if self.tmp_filename and os.path.exists(self.tmp_filename): os.unlink(self.tmp_filename) if self.idx_filename and os.path.exists(self.idx_filename): os.unlink(self.idx_filename) + for cert in self.transient_certs_created: for ext in (".crt", ".key"): certfile = os.path.join(self.config_dir, "transient_certs", cert+ext)