191 lines
4.1 KiB
Go
191 lines
4.1 KiB
Go
package tlsauth_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"tildegit.org/tjp/gus"
|
|
"tildegit.org/tjp/gus/contrib/tlsauth"
|
|
"tildegit.org/tjp/gus/gemini"
|
|
)
|
|
|
|
func TestIdentify(t *testing.T) {
|
|
invoked := false
|
|
|
|
var leafCert *x509.Certificate
|
|
server, client, clientCert := setup(t,
|
|
"testdata/server.crt", "testdata/server.key",
|
|
"testdata/client1.crt", "testdata/client1.key",
|
|
func(_ context.Context, request *gus.Request) *gus.Response {
|
|
invoked = true
|
|
|
|
ident := tlsauth.Identity(request)
|
|
if assert.NotNil(t, ident) {
|
|
assert.True(t, ident.Equal(leafCert))
|
|
}
|
|
|
|
return nil
|
|
},
|
|
)
|
|
leafCert, err := x509.ParseCertificate(clientCert.Certificate[0])
|
|
require.Nil(t, err)
|
|
|
|
go func() {
|
|
_ = server.Serve()
|
|
}()
|
|
defer server.Close()
|
|
|
|
requestPath(t, client, server, "/")
|
|
assert.True(t, invoked)
|
|
}
|
|
|
|
func TestRequiredAuth(t *testing.T) {
|
|
invoked1 := false
|
|
invoked2 := false
|
|
|
|
handler1 := func(_ context.Context, request *gus.Request) *gus.Response {
|
|
invoked1 = true
|
|
return gemini.Success("", &bytes.Buffer{})
|
|
}
|
|
|
|
handler2 := func(_ context.Context, request *gus.Request) *gus.Response {
|
|
invoked2 = true
|
|
return gemini.Success("", &bytes.Buffer{})
|
|
}
|
|
|
|
authMiddleware := gus.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil)
|
|
|
|
handler1 = gus.Filter(
|
|
func(_ context.Context, req *gus.Request) bool {
|
|
return strings.HasPrefix(req.Path, "/one")
|
|
},
|
|
nil,
|
|
)(authMiddleware(handler1))
|
|
handler2 = authMiddleware(handler2)
|
|
|
|
server, client, _ := setup(t,
|
|
"testdata/server.crt", "testdata/server.key",
|
|
"testdata/client1.crt", "testdata/client1.key",
|
|
gus.FallthroughHandler(handler1, handler2),
|
|
)
|
|
|
|
go func() {
|
|
_ = server.Serve()
|
|
}()
|
|
defer server.Close()
|
|
|
|
requestPath(t, client, server, "/one")
|
|
assert.True(t, invoked1)
|
|
|
|
client, _ = clientFor(t, server, "", "") // no client cert this time
|
|
requestPath(t, client, server, "/two")
|
|
assert.False(t, invoked2)
|
|
}
|
|
|
|
func TestOptionalAuth(t *testing.T) {
|
|
invoked1 := false
|
|
invoked2 := false
|
|
|
|
handler1 := func(_ context.Context, request *gus.Request) *gus.Response {
|
|
if !strings.HasPrefix(request.Path, "/one") {
|
|
return nil
|
|
}
|
|
|
|
invoked1 = true
|
|
return gemini.Success("", &bytes.Buffer{})
|
|
}
|
|
|
|
handler2 := func(_ context.Context, request *gus.Request) *gus.Response {
|
|
invoked2 = true
|
|
return gemini.Success("", &bytes.Buffer{})
|
|
}
|
|
|
|
mw := gus.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil)
|
|
handler := gus.FallthroughHandler(mw(handler1), mw(handler2))
|
|
|
|
server, client, _ := setup(t,
|
|
"testdata/server.crt", "testdata/server.key",
|
|
"testdata/client1.crt", "testdata/client1.key",
|
|
handler,
|
|
)
|
|
|
|
go func() {
|
|
_ = server.Serve()
|
|
}()
|
|
defer server.Close()
|
|
|
|
requestPath(t, client, server, "/one")
|
|
assert.False(t, invoked1)
|
|
|
|
client, _ = clientFor(t, server, "", "")
|
|
requestPath(t, client, server, "/two")
|
|
assert.True(t, invoked2)
|
|
}
|
|
|
|
func setup(
|
|
t *testing.T,
|
|
serverCertPath string,
|
|
serverKeyPath string,
|
|
clientCertPath string,
|
|
clientKeyPath string,
|
|
handler gus.Handler,
|
|
) (gus.Server, gemini.Client, tls.Certificate) {
|
|
serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath)
|
|
require.Nil(t, err)
|
|
|
|
server, err := gemini.NewServer(
|
|
context.Background(),
|
|
"localhost",
|
|
"tcp",
|
|
"127.0.0.1:0",
|
|
handler,
|
|
nil,
|
|
serverTLS,
|
|
)
|
|
require.Nil(t, err)
|
|
|
|
client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath)
|
|
|
|
return server, client, clientCert
|
|
}
|
|
|
|
func clientFor(
|
|
t *testing.T,
|
|
server gus.Server,
|
|
certPath string,
|
|
keyPath string,
|
|
) (gemini.Client, tls.Certificate) {
|
|
var clientCert tls.Certificate
|
|
var certs []tls.Certificate
|
|
if certPath != "" {
|
|
c, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
require.Nil(t, err)
|
|
|
|
clientCert = c
|
|
certs = []tls.Certificate{c}
|
|
}
|
|
|
|
return gemini.NewClient(&tls.Config{
|
|
Certificates: certs,
|
|
InsecureSkipVerify: true,
|
|
}), clientCert
|
|
}
|
|
|
|
func requestPath(t *testing.T, client gemini.Client, server gus.Server, path string) *gus.Response {
|
|
u, err := url.Parse("gemini://" + server.Address() + path)
|
|
require.Nil(t, err)
|
|
|
|
response, err := client.RoundTrip(&gus.Request{URL: u})
|
|
require.Nil(t, err)
|
|
|
|
return response
|
|
}
|