package main import ( "context" "crypto/sha256" "encoding/hex" "io" "os" "os/user" "path/filepath" "slices" "strings" "tildegit.org/tjp/sliderule" "tildegit.org/tjp/sliderule/gemini" ) func GeminiAuthMiddleware(auth *Auth) sliderule.Middleware { if auth == nil { return func(inner sliderule.Handler) sliderule.Handler { return inner } } return func(inner sliderule.Handler) sliderule.Handler { return sliderule.HandlerFunc(func(ctx context.Context, request *sliderule.Request) *sliderule.Response { if auth.Strategy.Approve(ctx, request) { return inner.Handle(ctx, request) } if len(request.TLSState.PeerCertificates) == 0 { return gemini.RequireCert("client certificate required") } return gemini.CertAuthFailure("client certificate rejected") }) } } func ClientTLSFile(path string) (AuthStrategy, error) { if strings.Contains(path, "~") { return UserClientTLSAuth(path), nil } f, err := os.Open(path) if err != nil { return nil, err } defer func() { _ = f.Close() }() contents, err := io.ReadAll(f) if err != nil { return nil, err } fingerprints := []string{} for _, line := range strings.Split(string(contents), "\n") { line = strings.Trim(line, " \t\r") if len(line) == sha256.Size*2 { fingerprints = append(fingerprints, line) } } return ClientTLSAuth(fingerprints), nil } func ClientTLS(raw string) AuthStrategy { fingerprints := []string{} for _, fp := range strings.Split(raw, ",") { fp = strings.Trim(fp, " \t\r") if len(fp) == sha256.Size*2 { fingerprints = append(fingerprints, fp) } } return ClientTLSAuth(fingerprints) } type UserClientTLSAuth string func (ca UserClientTLSAuth) Approve(ctx context.Context, request *sliderule.Request) bool { u, err := user.Lookup(sliderule.RouteParams(ctx)["username"]) if err != nil { return false } fpath := resolveTilde(string(ca), u) strat, err := ClientTLSFile(fpath) if err != nil { return false } return strat.Approve(ctx, request) } func resolveTilde(path string, u *user.User) string { if strings.HasPrefix(path, "~/") { return filepath.Join(u.HomeDir, path[1:]) } return strings.ReplaceAll(path, "~", u.Username) } type ClientTLSAuth []string func (ca ClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool { if request.TLSState == nil || len(request.TLSState.PeerCertificates) == 0 { return false } return slices.Contains(ca, fingerprint(request.TLSState.PeerCertificates[0].Raw)) } func fingerprint(raw []byte) string { hash := sha256.Sum256(raw) return hex.EncodeToString(hash[:]) } type HasClientTLSAuth struct{} func (_ HasClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool { return request.TLSState != nil && len(request.TLSState.PeerCertificates) > 0 }