From 23d705b93a89cb0aee582eda819a76257f42dffc Mon Sep 17 00:00:00 2001 From: tjpcc Date: Tue, 24 Jan 2023 07:36:28 -0700 Subject: [PATCH] Add support for titan:// to the gemini server Titan is a gemini add-on protocol so it really didn't make sense to build it out in a separate package. The most significant difference in titan for the purposes of implementation here is that requests can have bodies following the URL line. Since gus.Request is a struct, the only way to smuggle in the new field (a reader for the body) was to stash it in the context. --- gemini/request.go | 9 ++++- gemini/roundtrip_test.go | 80 +++++++++++++++++++++++++++------------- gemini/serve.go | 53 +++++++++++++++++++++++--- 3 files changed, 110 insertions(+), 32 deletions(-) diff --git a/gemini/request.go b/gemini/request.go index ced7d0b..5220952 100644 --- a/gemini/request.go +++ b/gemini/request.go @@ -13,8 +13,15 @@ import ( var InvalidRequestLineEnding = errors.New("invalid request line ending") // ParseRequest parses a single gemini request from a reader. +// +// If the reader argument is a *bufio.Reader, it will only read a single line from it. func ParseRequest(rdr io.Reader) (*gus.Request, error) { - line, err := bufio.NewReader(rdr).ReadString('\n') + bufrdr, ok := rdr.(*bufio.Reader) + if !ok { + bufrdr = bufio.NewReader(rdr) + } + + line, err := bufrdr.ReadString('\n') if err != io.EOF && err != nil { return nil, err } diff --git a/gemini/roundtrip_test.go b/gemini/roundtrip_test.go index 4bac239..4f48e47 100644 --- a/gemini/roundtrip_test.go +++ b/gemini/roundtrip_test.go @@ -9,56 +9,84 @@ import ( "net/url" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tildegit.org/tjp/gus" "tildegit.org/tjp/gus/gemini" ) func TestRoundTrip(t *testing.T) { tlsConf, err := gemini.FileTLS("./testdata/server.crt", "./testdata/server.key") - if err != nil { - t.Fatalf("FileTLS(): %s", err.Error()) - } + require.Nil(t, err) handler := func(ctx context.Context, req *gus.Request) *gus.Response { return gemini.Success("text/gemini", bytes.NewBufferString("you've found my page")) } server, err := gemini.NewServer(context.Background(), nil, tlsConf, "tcp", "127.0.0.1:0", handler) - if err != nil { - t.Fatalf("NewServer(): %s", err.Error()) - } + require.Nil(t, err) go server.Serve() defer server.Close() u, err := url.Parse(fmt.Sprintf("gemini://%s/test", server.Address())) - if err != nil { - t.Fatalf("url.Parse: %s", err.Error()) - } + require.Nil(t, err) cli := gemini.NewClient(testClientTLS()) response, err := cli.RoundTrip(&gus.Request{URL: u}) - if err != nil { - t.Fatalf("RoundTrip(): %s", err.Error()) - } + require.Nil(t, err) - if response.Status != gemini.StatusSuccess { - t.Errorf("response status: expected %d, got %d", gemini.StatusSuccess, response.Status) - } - if response.Meta != "text/gemini" { - t.Errorf("response meta: expected \"text/gemini\", got %q", response.Meta) - } + assert.Equal(t, gemini.StatusSuccess, response.Status) + assert.Equal(t, "text/gemini", response.Meta) - if response.Body == nil { - t.Fatal("succcess response has nil body") - } + require.NotNil(t, response.Body) body, err := io.ReadAll(response.Body) - if err != nil { - t.Fatalf("ReadAll: %s", err.Error()) - } - if string(body) != "you've found my page" { - t.Errorf("response body: expected \"you've found my page\", got %q", string(body)) + require.Nil(t, err) + + assert.Equal(t, "you've found my page", string(body)) +} + +func TestTitanRequest(t *testing.T) { + tlsConf, err := gemini.FileTLS("./testdata/server.crt", "./testdata/server.key") + require.Nil(t, err) + + invoked := false + handler := func(ctx context.Context, request *gus.Request) *gus.Response { + invoked = true + + body := ctx.Value(gemini.TitanRequestBody) + if !assert.NotNil(t, body) { + return gemini.Success("", nil) + } + + bodyBytes, err := io.ReadAll(body.(io.Reader)) + require.Nil(t, err) + + assert.Equal(t, "the request body\n", string(bodyBytes)) + return gemini.Success("", nil) } + + server, err := gemini.NewServer(context.Background(), nil, tlsConf, "tcp", "127.0.0.1:0", handler) + require.Nil(t, err) + + go server.Serve() + defer server.Close() + + conn, err := tls.Dial(server.Network(), server.Address(), testClientTLS()) + require.Nil(t, err) + + _, err = fmt.Fprintf( + conn, + "titan://%s/foobar;size=17;mime=text/plain\r\nthe request body\n", + server.Address(), + ) + require.Nil(t, err) + + _, err = io.ReadAll(conn) + require.Nil(t, err) + + assert.True(t, invoked) } func testClientTLS() *tls.Config { diff --git a/gemini/serve.go b/gemini/serve.go index cd51370..dd7ad52 100644 --- a/gemini/serve.go +++ b/gemini/serve.go @@ -1,16 +1,26 @@ package gemini import ( + "bufio" "context" "crypto/tls" + "errors" "io" "net" + "strconv" + "strings" "sync" "tildegit.org/tjp/gus" "tildegit.org/tjp/gus/logging" ) +// TitanRequestBody is the key set in a handler's context for titan requests. +// +// When this key is present in the context (request.URL.Scheme will be "titan"), the +// corresponding value is a *bufio.Reader from which the request body can be read. +const TitanRequestBody = "titan_request_body" + type server struct { ctx context.Context errorLog logging.Logger @@ -59,6 +69,10 @@ func NewServer( // It will respect cancellation of the context the server was created with, // but be aware that Close() must still be called in that case to avoid // dangling goroutines. +// +// On titan protocol requests, it sets a key/value pair in the context. The +// key is TitanRequestBody, and the value is a *bufio.Reader from which the +// request body can be read. func (s *server) Serve() error { s.wg.Add(1) defer s.wg.Done() @@ -74,7 +88,7 @@ func (s *server) Serve() error { if s.Closed() { err = nil } else { - s.errorLog.Log("msg", "accept_error", "error", err) + s.errorLog.Log("msg", "accept_error", "error", err) } return err @@ -112,11 +126,12 @@ func (s *server) handleConn(conn net.Conn) { defer s.wg.Done() defer conn.Close() + buf := bufio.NewReader(conn) + var response *gus.Response - req, err := ParseRequest(conn) + req, err := ParseRequest(buf) if err != nil { response = BadRequest(err.Error()) - return } else { req.Server = s req.RemoteAddr = conn.RemoteAddr() @@ -125,13 +140,25 @@ func (s *server) handleConn(conn net.Conn) { req.TLSState = &state } - response = s.handler(s.ctx, req) + ctx := s.ctx + if req.Scheme == "titan" { + len, err := sizeParam(req.Path) + if err == nil { + ctx = context.WithValue( + ctx, + "titan_request_body", + io.LimitReader(buf, int64(len)), + ) + } + } + + response = s.handler(ctx, req) if response == nil { response = NotFound("Resource does not exist.") } - defer response.Close() } + defer response.Close() _, _ = io.Copy(conn, NewResponseReader(response)) } @@ -152,3 +179,19 @@ func (s *server) Closed() bool { return false } } + +func sizeParam(path string) (int, error) { + _, rest, found := strings.Cut(path, ";") + if !found { + return 0, errors.New("no params in path") + } + + for _, piece := range strings.Split(rest, ";") { + key, val, _ := strings.Cut(piece, "=") + if key == "size" { + return strconv.Atoi(val) + } + } + + return 0, errors.New("no size param found") +}