This repository has been archived on 2023-05-01. You can view files and clone it, but cannot push or open issues or pull requests.
gus/contrib/tlsauth/auth_test.go

191 lines
4.1 KiB
Go

package tlsauth_test
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tildegit.org/tjp/gus"
"tildegit.org/tjp/gus/contrib/tlsauth"
"tildegit.org/tjp/gus/gemini"
)
// TestIdentify checks that tlsauth.Identity, called from inside a handler,
// returns a certificate equal to the leaf certificate the client presented.
func TestIdentify(t *testing.T) {
	invoked := false

	// leafCert is assigned after setup returns (the client certificate is only
	// available then) but before any request is sent, so the handler closure
	// below sees the parsed value.
	var leafCert *x509.Certificate

	server, client, clientCert := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		func(_ context.Context, request *gus.Request) *gus.Response {
			invoked = true
			ident := tlsauth.Identity(request)
			if assert.NotNil(t, ident) {
				assert.True(t, ident.Equal(leafCert))
			}
			return nil
		},
	)
	// Register cleanup immediately: a failing require below ends the test via
	// FailNow, and the server must still be closed in that case.
	defer server.Close()

	var err error
	leafCert, err = x509.ParseCertificate(clientCert.Certificate[0])
	require.NoError(t, err)

	go func() {
		_ = server.Serve()
	}()

	requestPath(t, client, server, "/")
	assert.True(t, invoked)
}
// TestRequiredAuth checks that tlsauth.RequiredAuth(tlsauth.Allow) lets a
// request reach the wrapped handler when a client certificate is presented,
// and keeps the wrapped handler from running when no certificate is presented.
func TestRequiredAuth(t *testing.T) {
	invoked1 := false
	invoked2 := false

	handler1 := func(_ context.Context, _ *gus.Request) *gus.Response {
		invoked1 = true
		return gemini.Success("", &bytes.Buffer{})
	}
	handler2 := func(_ context.Context, _ *gus.Request) *gus.Response {
		invoked2 = true
		return gemini.Success("", &bytes.Buffer{})
	}

	authMiddleware := gus.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil)

	// handler1 is additionally gated on the "/one" path prefix; requests for
	// other paths are left for handler2 via FallthroughHandler. Both handlers
	// sit behind the same auth middleware.
	handler1 = gus.Filter(
		func(_ context.Context, req *gus.Request) bool {
			return strings.HasPrefix(req.Path, "/one")
		},
		nil,
	)(authMiddleware(handler1))
	handler2 = authMiddleware(handler2)

	server, certClient, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		gus.FallthroughHandler(handler1, handler2),
	)
	go func() {
		_ = server.Serve()
	}()
	defer server.Close()

	// With a client certificate, the request reaches handler1.
	requestPath(t, certClient, server, "/one")
	assert.True(t, invoked1)

	// Without a client certificate, handler2 must not run.
	anonClient, _ := clientFor(t, server, "", "") // no client cert this time
	requestPath(t, anonClient, server, "/two")
	assert.False(t, invoked2)

	// Positive control: the same path WITH a certificate does reach handler2,
	// so the assertion above is attributable to the missing certificate and
	// not to handler2 being unreachable.
	certClient2, _ := clientFor(t, server, "testdata/client1.crt", "testdata/client1.key")
	requestPath(t, certClient2, server, "/two")
	assert.True(t, invoked2)
}
// TestOptionalAuth checks that tlsauth.OptionalAuth applies its approver only
// when a client certificate is presented. With tlsauth.Reject, a request that
// carries a certificate reaches neither handler, while a request without one
// passes through.
func TestOptionalAuth(t *testing.T) {
	invoked1 := false
	invoked2 := false

	// handler1 responds only to paths under "/one"; for anything else it
	// returns nil (no response), leaving the request to handler2.
	handler1 := func(_ context.Context, request *gus.Request) *gus.Response {
		if !strings.HasPrefix(request.Path, "/one") {
			return nil
		}
		invoked1 = true
		return gemini.Success("", &bytes.Buffer{})
	}
	handler2 := func(_ context.Context, _ *gus.Request) *gus.Response {
		invoked2 = true
		return gemini.Success("", &bytes.Buffer{})
	}

	mw := gus.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil)
	handler := gus.FallthroughHandler(mw(handler1), mw(handler2))

	server, client, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		handler,
	)
	go func() {
		_ = server.Serve()
	}()
	defer server.Close()

	// A certificate is presented, so Reject applies to both wrapped handlers:
	// neither may run, including the fallthrough handler.
	requestPath(t, client, server, "/one")
	assert.False(t, invoked1)
	assert.False(t, invoked2)

	// No certificate: the approver is not consulted and "/two" reaches handler2.
	client, _ = clientFor(t, server, "", "")
	requestPath(t, client, server, "/two")
	assert.True(t, invoked2)
}
// setup creates a Gemini server for tests and a client to talk to it.
//
// The server:
//   - uses the key pair at serverCertPath/serverKeyPath;
//   - serves handler;
//   - is configured for address "127.0.0.1:0" (an OS-assigned port) over "tcp".
//
// The returned client presents the key pair at clientCertPath/clientKeyPath
// (see clientFor). That loaded client certificate is also returned, so tests
// can compare it with what the server observed.
//
// The server is created but not started: callers are responsible for running
// server.Serve and closing the server. Any setup error fails the test
// immediately.
func setup(
	t *testing.T,
	serverCertPath string,
	serverKeyPath string,
	clientCertPath string,
	clientKeyPath string,
	handler gus.Handler,
) (gus.Server, gemini.Client, tls.Certificate) {
	t.Helper() // attribute require failures to the calling test

	serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath)
	require.NoError(t, err)

	server, err := gemini.NewServer(
		context.Background(),
		"localhost",
		"tcp",
		"127.0.0.1:0",
		handler,
		nil,
		serverTLS,
	)
	require.NoError(t, err)

	client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath)
	return server, client, clientCert
}
// clientFor returns a Gemini client for use against the test servers.
//
// When certPath is non-empty, the key pair at certPath/keyPath is loaded. It is
// presented as the client certificate and also returned. When certPath is
// empty, the client presents no certificate and the returned tls.Certificate
// is the zero value.
//
// Server certificate verification is disabled: the test servers use
// certificates from testdata (presumably self-signed) rather than CA-signed ones.
//
// NOTE(review): the server argument is currently unused. It is kept so existing
// call sites continue to compile.
func clientFor(
	t *testing.T,
	server gus.Server,
	certPath string,
	keyPath string,
) (gemini.Client, tls.Certificate) {
	t.Helper() // attribute require failures to the calling test

	var clientCert tls.Certificate
	var certs []tls.Certificate
	if certPath != "" {
		c, err := tls.LoadX509KeyPair(certPath, keyPath)
		require.NoError(t, err)
		clientCert = c
		certs = []tls.Certificate{c}
	}

	return gemini.NewClient(&tls.Config{
		Certificates:       certs,
		InsecureSkipVerify: true, // test server certificates are not CA-verifiable
	}), clientCert
}
// requestPath sends a request for "gemini://<server.Address()><path>" using
// client and returns the response. The test fails immediately if the URL does
// not parse or the round trip returns an error.
func requestPath(t *testing.T, client gemini.Client, server gus.Server, path string) *gus.Response {
	t.Helper() // attribute require failures to the calling test

	u, err := url.Parse("gemini://" + server.Address() + path)
	require.NoError(t, err)

	response, err := client.RoundTrip(&gus.Request{URL: u})
	require.NoError(t, err)
	return response
}