From fcf545c27c70fb795a631a42738c486c07092e83 Mon Sep 17 00:00:00 2001 From: tjpcc Date: Tue, 14 Feb 2023 20:10:57 -0700 Subject: [PATCH] Router improvements. - test coverage for Router, not just PathTree - Router.Mount() now flattens routes into the parent router - Router.Use() implemented to set middleware on a router itself --- internal/pathtree.go | 48 ++++++++++++++++++++++++++ router.go | 58 ++++++++++++++++---------------- router_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 29 deletions(-) create mode 100644 router_test.go diff --git a/internal/pathtree.go b/internal/pathtree.go index 563e85c..7da4c2b 100644 --- a/internal/pathtree.go +++ b/internal/pathtree.go @@ -32,6 +32,15 @@ func (pt *PathTree[V]) Add(pattern string, value V) { } } +type Route[V any] struct { + Pattern string + Value V +} + +func (pt PathTree[V]) Routes() []Route[V] { + return pt.tree.routes() +} + // pattern segment which must be a specific string ("/users/"). type segmentNode[V any] struct { label string @@ -216,6 +225,45 @@ func (st *subtree[V]) Add(pattern []string, value V) { } } +func (st subtree[V]) routes() []Route[V] { + routes := []Route[V]{} + for _, seg := range st.segments { + if seg.value != nil { + routes = append(routes, Route[V]{ + Pattern: seg.label, + Value: *seg.value, + }) + } + for _, r := range seg.subtree.routes() { + r.Pattern = seg.label + "/" + r.Pattern + routes = append(routes, r) + } + } + + for _, wc := range st.wildcards { + if wc.value != nil { + routes = append(routes, Route[V]{ + Pattern: ":" + wc.param, + Value: *wc.value, + }) + } + for _, r := range wc.subtree.routes() { + r.Pattern = ":" + wc.param + "/" + r.Pattern + routes = append(routes, r) + } + } + + if st.remainder != nil { + rn := *st.remainder + routes = append(routes, Route[V]{ + Pattern: "*" + rn.param, + Value: rn.value, + }) + } + + return routes +} + type childSegments[V any] []segmentNode[V] func (cs childSegments[V]) Len() int { return len(cs) } diff --git a/router.go b/router.go index 50cc41f..1d8e93d 100644 --- a/router.go +++ b/router.go @@ -2,8 +2,6 @@ package gus import ( "context" - "crypto/tls" - "net/url" "strings" "tildegit.org/tjp/gus/internal" @@ -27,11 +25,18 @@ import ( // The zero value is a usable Router which will fail to match any requst path. type Router struct { tree internal.PathTree[Handler] + + middleware []Middleware + routeAdded bool } // Route adds a handler to the router under a path pattern. -func (r Router) Route(pattern string, handler Handler) { +func (r *Router) Route(pattern string, handler Handler) { + for i := len(r.middleware) - 1; i >= 0; i-- { + handler = r.middleware[i](handler) + } r.tree.Add(pattern, handler) + r.routeAdded = true } // Handler matches against the request path and dipatches to a route handler. @@ -59,6 +64,8 @@ func (r Router) Handler(ctx context.Context, request *Request) *Response { } // Match returns the matched handler and captured path parameters, or nils. +// +// The returned handlers will be wrapped with any middleware attached to the router. func (r Router) Match(request *Request) (Handler, map[string]string) { handler, params := r.tree.Match(request.Path) if handler == nil { @@ -72,19 +79,27 @@ func (r Router) Match(request *Request) (Handler, map[string]string) { // The prefix pattern may include segment :wildcards, but no *remainder segment. The // mounted sub-router should have patterns which only include the portion of the path // after whatever was matched by the prefix pattern. -func (r Router) Mount(prefix string, subrouter *Router) { +func (r *Router) Mount(prefix string, subrouter *Router) { prefix = strings.TrimSuffix(prefix, "/") - r.Route(prefix+"/*"+subrouterPathKey, func(ctx context.Context, request *Request) *Response { - r := cloneRequest(request) - r.Path = "/" + RouteParams(ctx)[subrouterPathKey] - return subrouter.Handler(ctx, r) - }) - // TODO: better approach. the above works but it's a little hacky - // - add a method to PathTree that returns all the registered patterns - // and their associated handlers - // - have Mount pull those out of the subrouter, prepend the prefix to - // all its patterns, and re-add them to the parent router. + for _, subroute := range subrouter.tree.Routes() { + r.Route(prefix+"/"+subroute.Pattern, subroute.Value) + } +} + +// Use attaches a middleware to the router. +// +// Any routes set on the router will have their handlers decorated by the attached +// middlewares in reverse order (the first middleware attached will be the outer-most: +// first to see requests and the last to see responses). +// +// Use will panic if Route or Mount have already been called on the router - +// middlewares must be set before any routes. +func (r *Router) Use(mw Middleware) { + if r.routeAdded { + panic("all middlewares must be added prior to adding routes") + } + r.middleware = append(r.middleware, mw) } // RouteParams gathers captured path parameters from the request context. @@ -104,18 +119,3 @@ const subrouterPathKey = "subrouter_path" type routeParamsKeyType struct{} var routeParamsKey = routeParamsKeyType{} - -func cloneRequest(start *Request) *Request { - end := &Request{} - *end = *start - - end.URL = &url.URL{} - *end.URL = *start.URL - - if start.TLSState != nil { - end.TLSState = &tls.ConnectionState{} - *end.TLSState = *start.TLSState - } - - return end -} diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..6f9c915 --- /dev/null +++ b/router_test.go @@ -0,0 +1,80 @@ +package gus_test + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "tildegit.org/tjp/gus" + "tildegit.org/tjp/gus/gemini" +) + +func h1(_ context.Context, _ *gus.Request) *gus.Response { + return gemini.Success("", &bytes.Buffer{}) +} + +func mw1(h gus.Handler) gus.Handler { + return func(ctx context.Context, req *gus.Request) *gus.Response { + resp := h(ctx, req) + resp.Body = io.MultiReader(resp.Body, bytes.NewBufferString("\nmiddleware 1")) + return resp + } +} + +func mw2(h gus.Handler) gus.Handler { + return func(ctx context.Context, req *gus.Request) *gus.Response { + resp := h(ctx, req) + resp.Body = io.MultiReader(resp.Body, bytes.NewBufferString("\nmiddleware 2")) + return resp + } +} + +func TestRouterUse(t *testing.T) { + r := &gus.Router{} + r.Use(mw1) + r.Use(mw2) + r.Route("/", h1) + + request, err := gemini.ParseRequest(bytes.NewBufferString("/\r\n")) + require.Nil(t, err) + + response := r.Handler(context.Background(), request) + require.NotNil(t, response) + + body, err := io.ReadAll(response.Body) + require.Nil(t, err) + + assert.Equal(t, "\nmiddleware 2\nmiddleware 1", string(body)) +} + +func TestRouterMount(t *testing.T) { + outer := &gus.Router{} + outer.Use(mw2) + + inner := &gus.Router{} + inner.Use(mw1) + inner.Route("/bar", h1) + + outer.Mount("/foo", inner) + + request, err := gemini.ParseRequest(bytes.NewBufferString("/foo/bar\r\n")) + require.Nil(t, err) + + response := outer.Handler(context.Background(), request) + require.NotNil(t, response) + + body, err := io.ReadAll(response.Body) + require.Nil(t, err) + + assert.Equal(t, "\nmiddleware 1\nmiddleware 2", string(body)) + + request, err = gemini.ParseRequest(bytes.NewBufferString("/foo\r\n")) + require.Nil(t, err) + + response = outer.Handler(context.Background(), request) + require.Nil(t, response) +}