Refactor of networking logic.

1. Move client certificate handling inside of _send_request().
2. Change _fetch_over_network() to not be recursive, by just looping
   through calls to _send_request().  This facilitates moving the
   redirect-tracking state inside _fetch_over_network(), instead of
   keeping it in GeminiClient.
3. Also allow _fetch_over_network() to save the response to a provided
   filename, and use this (rather than _go_to_gi()) to implement
   do_save().  This avoids the need for awkward gymnastics with
   the internal state.
This commit is contained in:
Solderpunk 2023-11-14 19:11:52 +01:00
parent 01da844141
commit 713616d556
1 changed files with 162 additions and 139 deletions

301
av98.py
View File

@ -248,7 +248,6 @@ class GeminiClient(cmd.Cmd):
self.marks = {}
self.page_index = 0
self.permanent_redirects = {}
self.previous_redirectors = set()
self.restricted = restricted
self.tmp_filename = ""
self.visited_hosts = set()
@ -388,7 +387,125 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
body = None
return mime, body, gi.path
def _fetch_over_network(self, gi):
def _fetch_over_network(self, gi, destination=None):
    """Fetch `gi` over the network and save the response body to disk.

    Requests are sent with `_send_request()` in a loop until a success
    (2x) response arrives.  Input requests (1x) prompt the user and
    re-send the request with the answer as the query; redirects (3x)
    are followed after loop/limit checks and, where required, user
    confirmation; client certificate requests (6x) are handed to
    `_handle_cert_request()` before the request is retried.

    Args:
        gi: The item to fetch.
        destination: Optional filename, passed through to
            `_write_response_to_file()`.

    Returns:
        A `(gi, mime, body, filename)` tuple.  The returned `gi` is the
        item that was finally fetched, which differs from the argument
        after an input request or a redirect.

    Raises:
        RuntimeError: On a 4x/5x response (the message is the server's
            meta line), an undefined status code, a self-redirect, a
            redirect loop, too many consecutive redirects, or a header
            declaring an unknown charset.
        UserAbortException: If the user declines to follow a redirect.
        Exception: Anything raised by `_send_request()` is re-raised
            after the failure counters in `self.log` are updated.
    """
    previous_redirectors = set()
    while True:
        # Send request to server
        try:
            status, meta, address, f = self._send_request(gi)
        except Exception as err:
            # Count the well-known failure modes, then let the caller
            # deal with the error itself.
            if isinstance(err, socket.gaierror):
                self.log["dns_failures"] += 1
            elif isinstance(err, ConnectionRefusedError):
                self.log["refused_connections"] += 1
            elif isinstance(err, ConnectionResetError):
                self.log["reset_connections"] += 1
            elif isinstance(err, (TimeoutError, socket.timeout)):
                self.log["timeouts"] += 1
            raise

        # Update redirect loop/maze escaping state
        if not status.startswith("3"):
            previous_redirectors = set()

        # Handle non-SUCCESS headers, which don't have a response body
        # Inputs
        if status.startswith("1"):
            # The meta line carries the server's prompt; show it so the
            # user knows what they are being asked for.
            print(meta)
            if status == "11":
                # Read the answer without echoing it to the terminal
                user_input = getpass.getpass("> ")
            else:
                user_input = input("> ")
            gi = gi.query(user_input)
            continue

        # Redirects
        elif status.startswith("3"):
            new_gi = GeminiItem(gi.absolutise_url(meta))
            if new_gi.url == gi.url:
                raise RuntimeError("URL redirects to itself!")
            elif new_gi.url in previous_redirectors:
                raise RuntimeError("Caught in redirect loop!")
            elif len(previous_redirectors) == _MAX_REDIRECTS:
                raise RuntimeError("Refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
            # Never follow cross-domain redirects without asking
            elif new_gi.host != gi.host:
                follow = input("Follow cross-domain redirect to %s? (y/n) " % new_gi.url)
            # Never follow cross-protocol redirects without asking
            elif new_gi.scheme != gi.scheme:
                follow = input("Follow cross-protocol redirect to %s? (y/n) " % new_gi.url)
            # Don't follow *any* redirect without asking if auto-follow is off
            elif not self.options["auto_follow_redirects"]:
                follow = input("Follow redirect to %s? (y/n) " % new_gi.url)
            # Otherwise, follow away
            else:
                follow = "yes"
            if follow.strip().lower() not in ("y", "yes"):
                raise UserAbortException()
            # Record this hop before logging so the reported count
            # includes the redirect being followed.
            previous_redirectors.add(gi.url)
            ui_out.debug("Following redirect to %s." % new_gi.url)
            ui_out.debug("This is consecutive redirect number %d." % len(previous_redirectors))
            if status == "31":
                # Permanent redirect
                self.permanent_redirects[gi.url] = new_gi.url
            gi = new_gi
            continue

        # Errors
        elif status.startswith("4") or status.startswith("5"):
            raise RuntimeError(meta)

        # Client cert
        elif status.startswith("6"):
            self._handle_cert_request(meta, status, gi.host)
            continue

        # Invalid status
        elif not status.startswith("2"):
            raise RuntimeError("Server returned undefined status code %s!" % status)

        # If we're here, this must be a success and there's a response
        # body, so break out of the request loop
        break

    # Fill in default MIME type or validate a provided one
    # NOTE(review): the cgi module was removed from the standard library
    # in Python 3.13 (PEP 594); this call needs a replacement there.
    mime = meta
    if mime == "":
        mime = "text/gemini; charset=utf-8"
    mime, mime_options = cgi.parse_header(mime)
    if "charset" in mime_options:
        try:
            codecs.lookup(mime_options["charset"])
        except LookupError:
            raise RuntimeError("Header declared unknown encoding %s" % mime_options["charset"])

    # Save response body to disk
    body, size, filename = self._write_response_to_file(mime, mime_options, f, destination)
    ui_out.debug("Wrote %d byte response to %s." % (size, filename))

    # Maintain cache and log
    if self.options["cache"]:
        self.cache.add(gi.url, mime, filename)
    self._log_visit(gi, address, size)

    return gi, mime, body, filename
def _send_request(self, gi):
"""Send a selector to a given host and port.
Returns the resolved address and binary file with the reply."""
# Figure out which host to connect to
if gi.scheme == "gemini":
# For Gemini requests, connect to the host and port specified in the URL
host, port = gi.host, gi.port
elif gi.scheme == "gopher":
# For Gopher requests, use the configured proxy
host, port = self.options["gopher_proxy"].rsplit(":", 1)
ui_out.debug("Using gopher proxy: " + self.options["gopher_proxy"])
elif gi.scheme in ("http", "https"):
host, port = self.options["http_proxy"].rsplit(":",1)
ui_out.debug("Using http proxy: " + self.options["http_proxy"])
# Be careful with client certificates!
# Are we crossing a domain boundary?
@ -421,133 +538,6 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
print("Remaining unidentified.")
self.client_certs.pop(gi.host)
# Send request to server
try:
status, meta, address, f = self._send_request(gi)
except Exception as err:
if isinstance(err, socket.gaierror):
self.log["dns_failures"] += 1
elif isinstance(err, ConnectionRefusedError):
self.log["refused_connections"] += 1
elif isinstance(err, ConnectionResetError):
self.log["reset_connections"] += 1
elif isinstance(err, (TimeoutError, socket.timeout)):
self.log["timeouts"] += 1
raise err
# Update redirect loop/maze escaping state
if not status.startswith("3"):
self.previous_redirectors = set()
# Handle non-SUCCESS headers, which don't have a response body
# Inputs
if status.startswith("1"):
print(meta)
if status == "11":
user_input = getpass.getpass("> ")
else:
user_input = input("> ")
return self._fetch_over_network(gi.query(user_input))
# Redirects
elif status.startswith("3"):
new_gi = GeminiItem(gi.absolutise_url(meta))
if new_gi.url == gi.url:
raise RuntimeError("URL redirects to itself!")
elif new_gi.url in self.previous_redirectors:
raise RuntimeError("Caught in redirect loop!")
elif len(self.previous_redirectors) == _MAX_REDIRECTS:
raise RuntimeError("Refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
# Never follow cross-domain redirects without asking
elif new_gi.host != gi.host:
follow = input("Follow cross-domain redirect to %s? (y/n) " % new_gi.url)
# Never follow cross-protocol redirects without asking
elif new_gi.scheme != gi.scheme:
follow = input("Follow cross-protocol redirect to %s? (y/n) " % new_gi.url)
# Don't follow *any* redirect without asking if auto-follow is off
elif not self.options["auto_follow_redirects"]:
follow = input("Follow redirect to %s? (y/n) " % new_gi.url)
# Otherwise, follow away
else:
follow = "yes"
if follow.strip().lower() not in ("y", "yes"):
raise UserAbortException()
ui_out.debug("Following redirect to %s." % new_gi.url)
ui_out.debug("This is consecutive redirect number %d." % len(self.previous_redirectors))
self.previous_redirectors.add(gi.url)
if status == "31":
# Permanent redirect
self.permanent_redirects[gi.url] = new_gi.url
return self._fetch_over_network(new_gi)
# Errors
elif status.startswith("4") or status.startswith("5"):
raise RuntimeError(meta)
# Client cert
elif status.startswith("6"):
self._handle_cert_request(meta, status, gi.host)
return self._fetch_over_network(gi)
# Invalid status
elif not status.startswith("2"):
raise RuntimeError("Server returned undefined status code %s!" % status)
# If we're here, this must be a success and there's a response body
assert status.startswith("2")
mime = meta
if mime == "":
mime = "text/gemini; charset=utf-8"
mime, mime_options = cgi.parse_header(mime)
if "charset" in mime_options:
try:
codecs.lookup(mime_options["charset"])
except LookupError:
raise RuntimeError("Header declared unknown encoding %s" % value)
# Read the response body over the network
body = f.read()
# Save the result in a temporary file
## Set file mode
if mime.startswith("text/"):
mode = "w"
encoding = mime_options.get("charset", "UTF-8")
try:
body = body.decode(encoding)
except UnicodeError:
raise RuntimeError("Could not decode response body using %s encoding declared in header!" % encoding)
else:
mode = "wb"
encoding = None
## Write
tmpf = tempfile.NamedTemporaryFile(mode, encoding=encoding, delete=False)
size = tmpf.write(body)
tmpf.close()
self.tmp_filename = tmpf.name
ui_out.debug("Wrote %d byte response to %s." % (size, self.tmp_filename))
# Maintain cache and log
if self.options["cache"]:
self.cache.add(gi.url, mime, self.tmp_filename)
self._log_visit(gi, address, size)
return gi, mime, body, self.tmp_filename
def _send_request(self, gi):
"""Send a selector to a given host and port.
Returns the resolved address and binary file with the reply."""
if gi.scheme == "gemini":
# For Gemini requests, connect to the host and port specified in the URL
host, port = gi.host, gi.port
elif gi.scheme == "gopher":
# For Gopher requests, use the configured proxy
host, port = self.options["gopher_proxy"].rsplit(":", 1)
ui_out.debug("Using gopher proxy: " + self.options["gopher_proxy"])
elif gi.scheme in ("http", "https"):
host, port = self.options["http_proxy"].rsplit(":",1)
ui_out.debug("Using http proxy: " + self.options["http_proxy"])
# Do DNS resolution
addresses = self._get_addresses(host, port)
@ -646,6 +636,37 @@ Slow internet connection? Use 'set timeout' to be more patient.""")
return status, meta, address, f
def _write_response_to_file(self, mime, mime_options, f, destination):
    """Read the response body from `f` and write it to a file.

    Bodies whose MIME type starts with "text/" are decoded using the
    "charset" entry of `mime_options` (default "UTF-8") and written in
    text mode with that encoding; everything else is written as raw
    bytes.

    Args:
        mime: MIME type of the response, without parameters.
        mime_options: Mapping of MIME parameters; only "charset" is read.
        f: Readable binary file object holding the response body.
        destination: Filename to write to.  If falsy, a temporary file
            that is not deleted on close is created instead and its name
            is stored in `self.tmp_filename`.

    Returns:
        A `(body, size, filename)` tuple: the (decoded, for text) body,
        the value returned by the file's `write()` call, and the name of
        the file that was written.

    Raises:
        RuntimeError: If a text body cannot be decoded with the declared
            encoding.
    """
    # Read the response body over the network
    body = f.read()

    # Determine file mode; text is decoded up front so that a bad
    # encoding is reported before any file is created.
    if mime.startswith("text/"):
        mode = "w"
        encoding = mime_options.get("charset", "UTF-8")
        try:
            body = body.decode(encoding)
        except (UnicodeError, LookupError) as err:
            # LookupError covers a charset name Python does not know.
            raise RuntimeError("Could not decode response body using %s encoding declared in header!" % encoding) from err
    else:
        mode = "wb"
        encoding = None

    # Use a temporary file if a filename was not provided.  The context
    # managers make sure the file is closed even if the write fails.
    if destination:
        filename = destination
        with open(destination, mode, encoding=encoding) as fp:
            size = fp.write(body)
    else:
        with tempfile.NamedTemporaryFile(mode, encoding=encoding, delete=False) as fp:
            self.tmp_filename = fp.name
            size = fp.write(body)
        filename = self.tmp_filename

    return body, size, filename
def _get_addresses(self, host, port):
# DNS lookup - will get IPv4 and IPv6 records if IPv6 is enabled
if ":" in host:
@ -1296,37 +1317,39 @@ Use 'ls -l' to see URLs."""
print("You must provide an index, a filename, or both.")
return
# Next, fetch the item to save, if it's not the current one.
# Determine GI to save
if index:
last_gi = self.gi
try:
gi = self.lookup[index-1]
self._go_to_gi(gi, update_hist = False, handle = False)
saving_current = False
except IndexError:
print ("Index too high!")
self.gi = last_gi
return
else:
gi = self.gi
saving_current = True
# Derive filename from current GI's path, if one hasn't been set
if not filename:
filename = os.path.basename(gi.path)
# Check for filename collisions and actually do the save if safe
# Check for filename collisions
if os.path.exists(filename):
print("File %s already exists!" % filename)
else:
return
# Actually do the save operation
if saving_current:
# Don't use _get_active_tmpfile() here, because we want to save the
# "source code" of menus, not the rendered view - this way AV-98
# can navigate to it later.
src = gi.path if gi.scheme == "file" else self.tmp_filename
shutil.copyfile(src, filename)
print("Saved to %s" % filename)
else:
## Download an item that's not the current one
self._fetch_over_network(gi, filename)
# Restore gi if necessary
if index != None:
self._go_to_gi(last_gi, handle=False)
print("Saved to %s" % filename)
@needs_gi
def do_url(self, *args):