dial timeouts for clients, and catch up on test fixes

This commit is contained in:
tjp 2024-01-13 11:29:17 -07:00
parent de1490808f
commit 4d861a2c39
13 changed files with 89 additions and 54 deletions

View File

@ -1,12 +1,12 @@
package sliderule
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/url"
neturl "net/url"
"tildegit.org/tjp/sliderule/finger"
@ -18,7 +18,7 @@ import (
)
type protocolClient interface {
RoundTrip(*Request) (*Response, error)
RoundTrip(context.Context, *Request) (*Response, error)
IsRedirect(*Response) bool
}
@ -61,23 +61,23 @@ func NewClient(tlsConf *tls.Config) Client {
// RoundTrip sends a single request and returns the repsonse.
//
// If the response is a redirect it will be returned, rather than fetched.
func (c Client) RoundTrip(request *Request) (*Response, error) {
func (c Client) RoundTrip(ctx context.Context, request *Request) (*Response, error) {
pc, ok := c.protos[request.Scheme]
if !ok {
return nil, fmt.Errorf("unrecognized protocol: %s", request.Scheme)
}
return pc.RoundTrip(request)
return pc.RoundTrip(ctx, request)
}
// Fetch collects a resource from a URL including following any redirects.
func (c Client) Fetch(url string) (*Response, error) {
func (c Client) Fetch(ctx context.Context, url string) (*Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
response, err := c.RoundTrip(&types.Request{URL: u})
response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}
@ -100,23 +100,23 @@ func (c Client) Fetch(url string) (*Response, error) {
}
// Upload sends a request with a body and returns any redirect response.
func (c Client) Upload(url string, contents io.Reader) (*Response, error) {
func (c Client) Upload(ctx context.Context, url string, contents io.Reader) (*Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
switch u.Scheme {
case "titan", "spartan", "http", "https":
return c.RoundTrip(&types.Request{URL: u, Meta: contents})
return c.RoundTrip(ctx, &types.Request{URL: u, Meta: contents})
default:
return nil, fmt.Errorf("upload not supported on %s", u.Scheme)
}
}
func getRedirectLocation(prev *url.URL, proto string, meta any) string {
func getRedirectLocation(prev *neturl.URL, proto string, meta any) string {
switch proto {
case "gemini", "spartan":
u, _ := url.Parse(meta.(string))
u, _ := neturl.Parse(meta.(string))
return prev.ResolveReference(u).String()
case "http", "https":
return meta.(*http.Response).Header.Get("Location")
@ -128,9 +128,9 @@ type httpClient struct {
tp *http.Transport
}
func (hc httpClient) RoundTrip(request *Request) (*Response, error) {
func (hc httpClient) RoundTrip(ctx context.Context, request *Request) (*Response, error) {
body, _ := request.Meta.(io.Reader)
hreq, err := http.NewRequest("GET", request.URL.String(), body)
hreq, err := http.NewRequestWithContext(ctx, "GET", request.URL.String(), body)
if err != nil {
return nil, err
}

View File

@ -98,7 +98,7 @@ func requestPath(t *testing.T, client gemini.Client, server sr.Server, path stri
u, err := url.Parse("gemini://" + server.Address() + path)
require.Nil(t, err)
response, err := client.RoundTrip(&sr.Request{URL: u})
response, err := client.RoundTrip(context.Background(), &sr.Request{URL: u})
require.Nil(t, err)
return response

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"io"
"log"
@ -33,7 +34,7 @@ func main() {
request := &sr.Request{URL: buildURL(os.Args[1])}
// fetch the response
response, err := client.RoundTrip(request)
response, err := client.RoundTrip(context.Background(), request)
if err != nil {
log.Fatal(err)
}

View File

@ -2,6 +2,7 @@ package finger
import (
"bytes"
"context"
"errors"
"io"
"net"
@ -18,7 +19,7 @@ import (
type Client struct{}
// RoundTrip sends a single finger request and returns its response.
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "finger" && request.Scheme != "" {
return nil, errors.New("non-finger protocols not supported")
}
@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
host = net.JoinHostPort(host, "79")
}
conn, err := net.Dial("tcp", host)
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
@ -55,12 +56,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
// Fetch resolves a finger query.
func (c Client) Fetch(query string) (*types.Response, error) {
func (c Client) Fetch(ctx context.Context, query string) (*types.Response, error) {
req, err := ParseRequest(bytes.NewBufferString(query + "\r\n"))
if err != nil {
return nil, err
}
return c.RoundTrip(req)
return c.RoundTrip(ctx, req)
}
func (c Client) IsRedirect(_ *types.Response) bool { return false }

View File

@ -2,6 +2,7 @@ package gemini
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects")
//
// This method will not automatically follow redirects or cache permanent failures or
// redirects.
func (client Client) RoundTrip(request *types.Request) (*types.Response, error) {
func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" {
return nil, errors.New("non-gemini protocols not supported")
}
@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
tlsConf = &tls.Config{InsecureSkipVerify: true}
}
conn, err := tls.Dial("tcp", host, tlsConf)
conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
defer conn.Close()
request.RemoteAddr = conn.RemoteAddr()
st := conn.ConnectionState()
st := conn.(*tls.Conn).ConnectionState()
request.TLSState = &st
destURL := *request.URL
@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
// Fetch parses a URL string and fetches the gemini resource.
//
// It will resolve any redirects along the way, up to client.MaxRedirects.
func (c Client) Fetch(url string) (*types.Response, error) {
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
response, err := c.RoundTrip(&types.Request{URL: u})
response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}

View File

@ -92,7 +92,7 @@ func TestParsePromptLine(t *testing.T) {
if line.Type() != gemtext.LineTypePrompt{
t.Errorf("expected LineTypePrompt, got %d", line.Type())
}
link, ok := line.(gemtext.PromptLine)
link, ok := line.(gemtext.LinkLine)
if !ok {
t.Fatalf("expected a PromptLine, got %T", line)
}

View File

@ -78,8 +78,8 @@ This is some non-blank regular text.
assert.Equal(t, gemtext.LineTypePrompt, doc[12].Type())
assert.Equal(t, "=: spartan://foo.bar/baz this should be a spartan prompt\n", string(doc[12].Raw()))
assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.PromptLine).URL())
assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.PromptLine).Label())
assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.LinkLine).URL())
assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.LinkLine).Label())
assertEmptyLine(t, doc[13])

View File

@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tildegit.org/tjp/sliderule/internal/types"
"tildegit.org/tjp/sliderule/gemini"
"tildegit.org/tjp/sliderule/internal/types"
)
func TestRoundTrip(t *testing.T) {
@ -36,7 +36,7 @@ func TestRoundTrip(t *testing.T) {
require.Nil(t, err)
cli := gemini.NewClient(testClientTLS())
response, err := cli.RoundTrip(&types.Request{URL: u})
response, err := cli.RoundTrip(context.Background(), &types.Request{URL: u})
require.Nil(t, err)
assert.Equal(t, gemini.StatusSuccess, response.Status)

View File

@ -2,6 +2,7 @@ package gopher
import (
"bytes"
"context"
"errors"
"io"
"net"
@ -18,7 +19,7 @@ import (
type Client struct{}
// RoundTrip sends a single gopher request and returns its response.
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "gopher" && request.Scheme != "" {
return nil, errors.New("non-gopher protocols not supported")
}
@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
host = net.JoinHostPort(host, "70")
}
conn, err := net.Dial("tcp", host)
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
@ -56,12 +57,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
// Fetch parses a URL string and fetches the gopher resource.
func (c Client) Fetch(url string) (*types.Response, error) {
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
return c.RoundTrip(&types.Request{URL: u})
return c.RoundTrip(ctx, &types.Request{URL: u})
}
func (c Client) IsRedirect(_ *types.Response) bool { return false }

View File

@ -11,6 +11,6 @@ i /customlist.gophermap localhost.localdomain 70
0file4.txt /file4.txt localhost.localdomain 70
1subdir title /subdir localhost.localdomain 70
1subdir2 title /subdir2 localhost.localdomain 70
9uptime /uptime localhost.localdomain 70
0uptime /uptime localhost.localdomain 70
1uptime_output.gophermap /uptime_output.gophermap localhost.localdomain 70
.

View File

@ -2,6 +2,7 @@ package nex
import (
"bytes"
"context"
"errors"
"io"
"net"
@ -18,7 +19,7 @@ import (
type Client struct{}
// RoundTrip sends a single nex request and returns its response.
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "nex" && request.Scheme != "" {
return nil, errors.New("non-nex protocols not supported")
}
@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
host = net.JoinHostPort(host, "1900")
}
conn, err := net.Dial("tcp", host)
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
@ -50,12 +51,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
// Fetch builds and sends a nex request, and returns the response.
func (c Client) Fetch(url string) (*types.Response, error) {
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
return c.RoundTrip(&types.Request{URL: u})
return c.RoundTrip(ctx, &types.Request{URL: u})
}
func (c Client) IsRedirect(response *types.Response) bool { return false }

View File

@ -2,6 +2,7 @@ package spartan
import (
"bytes"
"context"
"errors"
"io"
"net"
@ -16,7 +17,7 @@ import (
// It carries no state and is reusable simultaneously by multiple goroutines.
//
// The zero value is immediately usabble, but will not follow redirects.
type Client struct{
type Client struct {
MaxRedirects int
}
@ -32,7 +33,7 @@ const DefaultMaxRedirects int = 2
var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects")
// RoundTrip sends a single spartan request and returns its response.
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "spartan" && request.Scheme != "" {
return nil, errors.New("non-spartan protocols not supported")
}
@ -44,7 +45,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
addr := net.JoinHostPort(host, port)
conn, err := net.Dial("tcp", addr)
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
@ -90,14 +91,14 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
// Fetch parses a URL string and fetches the spartan resource.
//
// It will resolve any redirects along the way, up to client.MaxRedirects.
func (c Client) Fetch(url string) (*types.Response, error) {
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
response, err := c.RoundTrip(&types.Request{URL: u})
response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}

View File

@ -1,12 +1,14 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"net/url"
"os"
"time"
"tildegit.org/tjp/sliderule"
"tildegit.org/tjp/sliderule/gemini"
@ -17,29 +19,45 @@ const usage = `Resource fetcher for the small web.
Usage:
sw-fetch (-h | --help)
sw-fetch [-v | --verbose] [-o PATH | --output PATH] [-k | --keyfile PATH] [ -c | --certfile PATH ] [ -s | --skip-verify ] [ -u | --upload ] URL
sw-fetch
[-v | --verbose]
[-o PATH | --output PATH]
[-k | --keyfile PATH]
[ -c | --certfile PATH ]
[ -s | --skip-verify ]
[ -t | --timeout TIMEOUT ]
[ -u | --upload ]
URL
Options:
-h --help Show this screen.
-v --verbose Display more diagnostic information on standard error.
-o --output PATH Send the fetched resource to PATH instead of standard out.
-k --keyfile PATH Path to the TLS key file to use.
-c --certfile PATH Path to the TLS certificate file to use.
-s --skip-verify Don't verify server TLS certificates.
-u --upload Use stdin as the request body on supported protocols and don't follow redirects.
-h --help Show this screen.
-v --verbose Display more diagnostic information on standard error.
-o --output PATH Send the fetched resource to PATH instead of standard out.
-k --keyfile PATH Path to the TLS key file to use.
-c --certfile PATH Path to the TLS certificate file to use.
-s --skip-verify Don't verify server TLS certificates.
-t --timeout TIMEOUT Fail after the given timeout (like "15s").
-u --upload Use stdin as the request body on supported protocols and don't follow redirects.
`
func main() {
conf := configure()
cl := sliderule.NewClient(conf.clientTLS)
ctx := context.Background()
if conf.timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, conf.timeout)
defer cancel()
}
var response *sliderule.Response
var err error
if conf.upload {
response, err = cl.Upload(conf.url.String(), os.Stdin)
response, err = cl.Upload(ctx, conf.url.String(), os.Stdin)
} else {
response, err = cl.Fetch(conf.url.String())
response, err = cl.Fetch(ctx, conf.url.String())
}
if err != nil {
fail(err.Error() + "\n")
@ -61,6 +79,7 @@ type config struct {
output io.WriteCloser
url *url.URL
clientTLS *tls.Config
timeout time.Duration
}
func configure() config {
@ -72,6 +91,7 @@ func configure() config {
key := ""
cert := ""
verify := true
var err error
for i := 1; i <= len(os.Args)-1; i += 1 {
switch os.Args[i] {
@ -87,12 +107,11 @@ func configure() config {
out := os.Args[i+1]
if out != "-" {
output, err := os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
conf.output, err = os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
fmt.Println(err.Error())
failf("'%s' is not a valid path\n", out)
}
conf.output = output
}
i += 1
@ -112,6 +131,16 @@ func configure() config {
cert = os.Args[i]
case "-s", "--skip-verify":
verify = false
case "-t", "--timeout":
if i+1 == len(os.Args)-1 {
fail(usage)
}
i += 1
conf.timeout, err = time.ParseDuration(os.Args[i])
if err != nil {
fail(err.Error())
}
case "-u", "--upload":
conf.upload = true
}