dial timeouts for clients, and catch up on test fixes
This commit is contained in:
parent
de1490808f
commit
4d861a2c39
24
client.go
24
client.go
|
@ -1,12 +1,12 @@
|
|||
package sliderule
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
neturl "net/url"
|
||||
|
||||
"tildegit.org/tjp/sliderule/finger"
|
||||
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
type protocolClient interface {
|
||||
RoundTrip(*Request) (*Response, error)
|
||||
RoundTrip(context.Context, *Request) (*Response, error)
|
||||
IsRedirect(*Response) bool
|
||||
}
|
||||
|
||||
|
@ -61,23 +61,23 @@ func NewClient(tlsConf *tls.Config) Client {
|
|||
// RoundTrip sends a single request and returns the repsonse.
|
||||
//
|
||||
// If the response is a redirect it will be returned, rather than fetched.
|
||||
func (c Client) RoundTrip(request *Request) (*Response, error) {
|
||||
func (c Client) RoundTrip(ctx context.Context, request *Request) (*Response, error) {
|
||||
pc, ok := c.protos[request.Scheme]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unrecognized protocol: %s", request.Scheme)
|
||||
}
|
||||
return pc.RoundTrip(request)
|
||||
return pc.RoundTrip(ctx, request)
|
||||
}
|
||||
|
||||
// Fetch collects a resource from a URL including following any redirects.
|
||||
func (c Client) Fetch(url string) (*Response, error) {
|
||||
func (c Client) Fetch(ctx context.Context, url string) (*Response, error) {
|
||||
u, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i <= c.MaxRedirects; i += 1 {
|
||||
response, err := c.RoundTrip(&types.Request{URL: u})
|
||||
response, err := c.RoundTrip(ctx, &types.Request{URL: u})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -100,23 +100,23 @@ func (c Client) Fetch(url string) (*Response, error) {
|
|||
}
|
||||
|
||||
// Upload sends a request with a body and returns any redirect response.
|
||||
func (c Client) Upload(url string, contents io.Reader) (*Response, error) {
|
||||
func (c Client) Upload(ctx context.Context, url string, contents io.Reader) (*Response, error) {
|
||||
u, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "titan", "spartan", "http", "https":
|
||||
return c.RoundTrip(&types.Request{URL: u, Meta: contents})
|
||||
return c.RoundTrip(ctx, &types.Request{URL: u, Meta: contents})
|
||||
default:
|
||||
return nil, fmt.Errorf("upload not supported on %s", u.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func getRedirectLocation(prev *url.URL, proto string, meta any) string {
|
||||
func getRedirectLocation(prev *neturl.URL, proto string, meta any) string {
|
||||
switch proto {
|
||||
case "gemini", "spartan":
|
||||
u, _ := url.Parse(meta.(string))
|
||||
u, _ := neturl.Parse(meta.(string))
|
||||
return prev.ResolveReference(u).String()
|
||||
case "http", "https":
|
||||
return meta.(*http.Response).Header.Get("Location")
|
||||
|
@ -128,9 +128,9 @@ type httpClient struct {
|
|||
tp *http.Transport
|
||||
}
|
||||
|
||||
func (hc httpClient) RoundTrip(request *Request) (*Response, error) {
|
||||
func (hc httpClient) RoundTrip(ctx context.Context, request *Request) (*Response, error) {
|
||||
body, _ := request.Meta.(io.Reader)
|
||||
hreq, err := http.NewRequest("GET", request.URL.String(), body)
|
||||
hreq, err := http.NewRequestWithContext(ctx, "GET", request.URL.String(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ func requestPath(t *testing.T, client gemini.Client, server sr.Server, path stri
|
|||
u, err := url.Parse("gemini://" + server.Address() + path)
|
||||
require.Nil(t, err)
|
||||
|
||||
response, err := client.RoundTrip(&sr.Request{URL: u})
|
||||
response, err := client.RoundTrip(context.Background(), &sr.Request{URL: u})
|
||||
require.Nil(t, err)
|
||||
|
||||
return response
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -33,7 +34,7 @@ func main() {
|
|||
request := &sr.Request{URL: buildURL(os.Args[1])}
|
||||
|
||||
// fetch the response
|
||||
response, err := client.RoundTrip(request)
|
||||
response, err := client.RoundTrip(context.Background(), request)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package finger
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -18,7 +19,7 @@ import (
|
|||
type Client struct{}
|
||||
|
||||
// RoundTrip sends a single finger request and returns its response.
|
||||
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
||||
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
|
||||
if request.Scheme != "finger" && request.Scheme != "" {
|
||||
return nil, errors.New("non-finger protocols not supported")
|
||||
}
|
||||
|
@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
host = net.JoinHostPort(host, "79")
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", host)
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -55,12 +56,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
}
|
||||
|
||||
// Fetch resolves a finger query.
|
||||
func (c Client) Fetch(query string) (*types.Response, error) {
|
||||
func (c Client) Fetch(ctx context.Context, query string) (*types.Response, error) {
|
||||
req, err := ParseRequest(bytes.NewBufferString(query + "\r\n"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.RoundTrip(req)
|
||||
return c.RoundTrip(ctx, req)
|
||||
}
|
||||
|
||||
func (c Client) IsRedirect(_ *types.Response) bool { return false }
|
||||
|
|
|
@ -2,6 +2,7 @@ package gemini
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
|
@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects")
|
|||
//
|
||||
// This method will not automatically follow redirects or cache permanent failures or
|
||||
// redirects.
|
||||
func (client Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
||||
func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
|
||||
if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" {
|
||||
return nil, errors.New("non-gemini protocols not supported")
|
||||
}
|
||||
|
@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
|
|||
tlsConf = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", host, tlsConf)
|
||||
conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
request.RemoteAddr = conn.RemoteAddr()
|
||||
st := conn.ConnectionState()
|
||||
st := conn.(*tls.Conn).ConnectionState()
|
||||
request.TLSState = &st
|
||||
|
||||
destURL := *request.URL
|
||||
|
@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
|
|||
// Fetch parses a URL string and fetches the gemini resource.
|
||||
//
|
||||
// It will resolve any redirects along the way, up to client.MaxRedirects.
|
||||
func (c Client) Fetch(url string) (*types.Response, error) {
|
||||
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
|
||||
u, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i <= c.MaxRedirects; i += 1 {
|
||||
response, err := c.RoundTrip(&types.Request{URL: u})
|
||||
response, err := c.RoundTrip(ctx, &types.Request{URL: u})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -92,7 +92,7 @@ func TestParsePromptLine(t *testing.T) {
|
|||
if line.Type() != gemtext.LineTypePrompt{
|
||||
t.Errorf("expected LineTypePrompt, got %d", line.Type())
|
||||
}
|
||||
link, ok := line.(gemtext.PromptLine)
|
||||
link, ok := line.(gemtext.LinkLine)
|
||||
if !ok {
|
||||
t.Fatalf("expected a PromptLine, got %T", line)
|
||||
}
|
||||
|
|
|
@ -78,8 +78,8 @@ This is some non-blank regular text.
|
|||
|
||||
assert.Equal(t, gemtext.LineTypePrompt, doc[12].Type())
|
||||
assert.Equal(t, "=: spartan://foo.bar/baz this should be a spartan prompt\n", string(doc[12].Raw()))
|
||||
assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.PromptLine).URL())
|
||||
assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.PromptLine).Label())
|
||||
assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.LinkLine).URL())
|
||||
assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.LinkLine).Label())
|
||||
|
||||
assertEmptyLine(t, doc[13])
|
||||
|
||||
|
|
|
@ -12,8 +12,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"tildegit.org/tjp/sliderule/internal/types"
|
||||
"tildegit.org/tjp/sliderule/gemini"
|
||||
"tildegit.org/tjp/sliderule/internal/types"
|
||||
)
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
|
@ -36,7 +36,7 @@ func TestRoundTrip(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
cli := gemini.NewClient(testClientTLS())
|
||||
response, err := cli.RoundTrip(&types.Request{URL: u})
|
||||
response, err := cli.RoundTrip(context.Background(), &types.Request{URL: u})
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.Equal(t, gemini.StatusSuccess, response.Status)
|
||||
|
|
|
@ -2,6 +2,7 @@ package gopher
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -18,7 +19,7 @@ import (
|
|||
type Client struct{}
|
||||
|
||||
// RoundTrip sends a single gopher request and returns its response.
|
||||
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
||||
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
|
||||
if request.Scheme != "gopher" && request.Scheme != "" {
|
||||
return nil, errors.New("non-gopher protocols not supported")
|
||||
}
|
||||
|
@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
host = net.JoinHostPort(host, "70")
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", host)
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -56,12 +57,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
}
|
||||
|
||||
// Fetch parses a URL string and fetches the gopher resource.
|
||||
func (c Client) Fetch(url string) (*types.Response, error) {
|
||||
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
|
||||
u, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.RoundTrip(&types.Request{URL: u})
|
||||
return c.RoundTrip(ctx, &types.Request{URL: u})
|
||||
}
|
||||
|
||||
func (c Client) IsRedirect(_ *types.Response) bool { return false }
|
||||
|
|
|
@ -11,6 +11,6 @@ i /customlist.gophermap localhost.localdomain 70
|
|||
0file4.txt /file4.txt localhost.localdomain 70
|
||||
1subdir title /subdir localhost.localdomain 70
|
||||
1subdir2 title /subdir2 localhost.localdomain 70
|
||||
9uptime /uptime localhost.localdomain 70
|
||||
0uptime /uptime localhost.localdomain 70
|
||||
1uptime_output.gophermap /uptime_output.gophermap localhost.localdomain 70
|
||||
.
|
||||
|
|
|
@ -2,6 +2,7 @@ package nex
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -18,7 +19,7 @@ import (
|
|||
type Client struct{}
|
||||
|
||||
// RoundTrip sends a single nex request and returns its response.
|
||||
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
||||
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
|
||||
if request.Scheme != "nex" && request.Scheme != "" {
|
||||
return nil, errors.New("non-nex protocols not supported")
|
||||
}
|
||||
|
@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
host = net.JoinHostPort(host, "1900")
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", host)
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -50,12 +51,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
}
|
||||
|
||||
// Fetch builds and sends a nex request, and returns the response.
|
||||
func (c Client) Fetch(url string) (*types.Response, error) {
|
||||
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
|
||||
u, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.RoundTrip(&types.Request{URL: u})
|
||||
return c.RoundTrip(ctx, &types.Request{URL: u})
|
||||
}
|
||||
|
||||
func (c Client) IsRedirect(response *types.Response) bool { return false }
|
||||
|
|
|
@ -2,6 +2,7 @@ package spartan
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -16,7 +17,7 @@ import (
|
|||
// It carries no state and is reusable simultaneously by multiple goroutines.
|
||||
//
|
||||
// The zero value is immediately usabble, but will not follow redirects.
|
||||
type Client struct{
|
||||
type Client struct {
|
||||
MaxRedirects int
|
||||
}
|
||||
|
||||
|
@ -32,7 +33,7 @@ const DefaultMaxRedirects int = 2
|
|||
var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects")
|
||||
|
||||
// RoundTrip sends a single spartan request and returns its response.
|
||||
func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
||||
func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
|
||||
if request.Scheme != "spartan" && request.Scheme != "" {
|
||||
return nil, errors.New("non-spartan protocols not supported")
|
||||
}
|
||||
|
@ -44,7 +45,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
}
|
||||
addr := net.JoinHostPort(host, port)
|
||||
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -90,14 +91,14 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
|
|||
// Fetch parses a URL string and fetches the spartan resource.
|
||||
//
|
||||
// It will resolve any redirects along the way, up to client.MaxRedirects.
|
||||
func (c Client) Fetch(url string) (*types.Response, error) {
|
||||
func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
|
||||
u, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i <= c.MaxRedirects; i += 1 {
|
||||
response, err := c.RoundTrip(&types.Request{URL: u})
|
||||
response, err := c.RoundTrip(ctx, &types.Request{URL: u})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"tildegit.org/tjp/sliderule"
|
||||
"tildegit.org/tjp/sliderule/gemini"
|
||||
|
@ -17,29 +19,45 @@ const usage = `Resource fetcher for the small web.
|
|||
|
||||
Usage:
|
||||
sw-fetch (-h | --help)
|
||||
sw-fetch [-v | --verbose] [-o PATH | --output PATH] [-k | --keyfile PATH] [ -c | --certfile PATH ] [ -s | --skip-verify ] [ -u | --upload ] URL
|
||||
sw-fetch
|
||||
[-v | --verbose]
|
||||
[-o PATH | --output PATH]
|
||||
[-k | --keyfile PATH]
|
||||
[ -c | --certfile PATH ]
|
||||
[ -s | --skip-verify ]
|
||||
[ -t | --timeout TIMEOUT ]
|
||||
[ -u | --upload ]
|
||||
URL
|
||||
|
||||
Options:
|
||||
-h --help Show this screen.
|
||||
-v --verbose Display more diagnostic information on standard error.
|
||||
-o --output PATH Send the fetched resource to PATH instead of standard out.
|
||||
-k --keyfile PATH Path to the TLS key file to use.
|
||||
-c --certfile PATH Path to the TLS certificate file to use.
|
||||
-s --skip-verify Don't verify server TLS certificates.
|
||||
-u --upload Use stdin as the request body on supported protocols and don't follow redirects.
|
||||
-h --help Show this screen.
|
||||
-v --verbose Display more diagnostic information on standard error.
|
||||
-o --output PATH Send the fetched resource to PATH instead of standard out.
|
||||
-k --keyfile PATH Path to the TLS key file to use.
|
||||
-c --certfile PATH Path to the TLS certificate file to use.
|
||||
-s --skip-verify Don't verify server TLS certificates.
|
||||
-t --timeout TIMEOUT Fail after the given timeout (like "15s").
|
||||
-u --upload Use stdin as the request body on supported protocols and don't follow redirects.
|
||||
`
|
||||
|
||||
func main() {
|
||||
conf := configure()
|
||||
cl := sliderule.NewClient(conf.clientTLS)
|
||||
|
||||
ctx := context.Background()
|
||||
if conf.timeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, conf.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var response *sliderule.Response
|
||||
var err error
|
||||
|
||||
if conf.upload {
|
||||
response, err = cl.Upload(conf.url.String(), os.Stdin)
|
||||
response, err = cl.Upload(ctx, conf.url.String(), os.Stdin)
|
||||
} else {
|
||||
response, err = cl.Fetch(conf.url.String())
|
||||
response, err = cl.Fetch(ctx, conf.url.String())
|
||||
}
|
||||
if err != nil {
|
||||
fail(err.Error() + "\n")
|
||||
|
@ -61,6 +79,7 @@ type config struct {
|
|||
output io.WriteCloser
|
||||
url *url.URL
|
||||
clientTLS *tls.Config
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func configure() config {
|
||||
|
@ -72,6 +91,7 @@ func configure() config {
|
|||
key := ""
|
||||
cert := ""
|
||||
verify := true
|
||||
var err error
|
||||
|
||||
for i := 1; i <= len(os.Args)-1; i += 1 {
|
||||
switch os.Args[i] {
|
||||
|
@ -87,12 +107,11 @@ func configure() config {
|
|||
|
||||
out := os.Args[i+1]
|
||||
if out != "-" {
|
||||
output, err := os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
conf.output, err = os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
failf("'%s' is not a valid path\n", out)
|
||||
}
|
||||
conf.output = output
|
||||
}
|
||||
|
||||
i += 1
|
||||
|
@ -112,6 +131,16 @@ func configure() config {
|
|||
cert = os.Args[i]
|
||||
case "-s", "--skip-verify":
|
||||
verify = false
|
||||
case "-t", "--timeout":
|
||||
if i+1 == len(os.Args)-1 {
|
||||
fail(usage)
|
||||
}
|
||||
|
||||
i += 1
|
||||
conf.timeout, err = time.ParseDuration(os.Args[i])
|
||||
if err != nil {
|
||||
fail(err.Error())
|
||||
}
|
||||
case "-u", "--upload":
|
||||
conf.upload = true
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue