sr-71/auth.go

116 lines
2.7 KiB
Go

package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"io"
"os"
"os/user"
"path/filepath"
"slices"
"strings"
"tildegit.org/tjp/sliderule"
"tildegit.org/tjp/sliderule/gemini"
)
// GeminiAuthMiddleware builds middleware that lets a request reach the
// wrapped handler only when auth.Strategy approves it. A nil auth disables
// the check: the returned middleware hands back the inner handler unchanged.
//
// Rejected requests get gemini.RequireCert when no client certificate was
// presented, and gemini.CertAuthFailure when a certificate was presented but
// not approved.
func GeminiAuthMiddleware(auth *Auth) sliderule.Middleware {
	if auth == nil {
		return func(inner sliderule.Handler) sliderule.Handler { return inner }
	}
	return func(inner sliderule.Handler) sliderule.Handler {
		return sliderule.HandlerFunc(func(ctx context.Context, request *sliderule.Request) *sliderule.Response {
			if auth.Strategy.Approve(ctx, request) {
				return inner.Handle(ctx, request)
			}
			// TLSState may be nil (the strategies in this file guard against
			// it too); check before dereferencing so a rejection can't panic.
			if request.TLSState == nil || len(request.TLSState.PeerCertificates) == 0 {
				return gemini.RequireCert("client certificate required")
			}
			return gemini.CertAuthFailure("client certificate rejected")
		})
	}
}
// ClientTLSFile builds an AuthStrategy from a file of hex-encoded SHA-256
// client-certificate fingerprints, one per line.
//
// A path containing "~" is treated as a per-user template and is not read
// here; a UserClientTLSAuth wrapping the path is returned instead. Otherwise
// the file is read immediately and lines that are not 64-digit hex strings
// (for example blank lines) are skipped. Errors opening or reading the file
// are returned.
func ClientTLSFile(path string) (AuthStrategy, error) {
	if strings.Contains(path, "~") {
		return UserClientTLSAuth(path), nil
	}

	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// Read-only handle: a Close error cannot lose data, so it is ignored.
	defer func() { _ = f.Close() }()

	contents, err := io.ReadAll(f)
	if err != nil {
		return nil, err
	}
	return parseFingerprints(strings.Split(string(contents), "\n")), nil
}

// ClientTLS builds an AuthStrategy from a comma-separated list of
// hex-encoded SHA-256 client-certificate fingerprints. Entries that are not
// 64-digit hex strings are skipped.
func ClientTLS(raw string) AuthStrategy {
	return parseFingerprints(strings.Split(raw, ","))
}

// parseFingerprints keeps the entries that are 64-digit hex strings after
// trimming surrounding whitespace. Entries are lowercased so they match the
// lowercase hex produced by fingerprint. The result is never nil.
func parseFingerprints(entries []string) ClientTLSAuth {
	fingerprints := ClientTLSAuth{}
	for _, entry := range entries {
		entry = strings.ToLower(strings.TrimSpace(entry))
		if len(entry) != sha256.Size*2 {
			continue
		}
		// Reject same-length strings that are not hex; they can never match.
		if _, err := hex.DecodeString(entry); err != nil {
			continue
		}
		fingerprints = append(fingerprints, entry)
	}
	return fingerprints
}
// UserClientTLSAuth is a fingerprint-file path template containing "~".
// At request time, "~" is resolved (via resolveTilde) against the account
// named by the "username" route parameter.
type UserClientTLSAuth string

// Approve resolves the path template for the routed user, builds a strategy
// from the resolved path with ClientTLSFile, and defers to that strategy.
// An unknown user or an unreadable file denies the request.
func (ca UserClientTLSAuth) Approve(ctx context.Context, request *sliderule.Request) bool {
	username := sliderule.RouteParams(ctx)["username"]
	owner, lookupErr := user.Lookup(username)
	if lookupErr != nil {
		return false
	}

	resolved := resolveTilde(string(ca), owner)
	delegate, loadErr := ClientTLSFile(resolved)
	if loadErr != nil {
		return false
	}
	return delegate.Approve(ctx, request)
}
// resolveTilde expands "~" in path for user u.
//
// A path that starts with "~/" is taken relative to u's home directory; only
// the leading "~" is expanded in that case. Otherwise every "~" in path is
// replaced by u's username.
func resolveTilde(path string, u *user.User) string {
	if !strings.HasPrefix(path, "~/") {
		return strings.Replace(path, "~", u.Username, -1)
	}
	// Drop only the "~", keeping the slash; filepath.Join cleans the result.
	return filepath.Join(u.HomeDir, strings.TrimPrefix(path, "~"))
}
// ClientTLSAuth is an allow-list of hex-encoded SHA-256 client-certificate
// fingerprints.
type ClientTLSAuth []string

// Approve reports whether the first peer certificate presented on the
// request has a SHA-256 fingerprint in ca. The comparison is
// case-insensitive, so allow-list entries written in uppercase hex still
// match. Requests without TLS state or without a peer certificate are
// rejected.
func (ca ClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool {
	if request.TLSState == nil || len(request.TLSState.PeerCertificates) == 0 {
		return false
	}
	got := fingerprint(request.TLSState.PeerCertificates[0].Raw)
	return slices.ContainsFunc(ca, func(want string) bool {
		return strings.EqualFold(want, got)
	})
}
// fingerprint returns the SHA-256 digest of raw, encoded as a 64-character
// lowercase hex string.
func fingerprint(raw []byte) string {
	digest := sha256.Sum256(raw)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
// HasClientTLSAuth approves any request that presents at least one client
// certificate, without checking which certificate it is.
type HasClientTLSAuth struct{}

// Approve reports whether the request carries TLS state with at least one
// peer certificate.
func (HasClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool {
	return request.TLSState != nil && len(request.TLSState.PeerCertificates) > 0
}