116 lines
2.7 KiB
Go
116 lines
2.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"io"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
|
|
"tildegit.org/tjp/sliderule"
|
|
"tildegit.org/tjp/sliderule/gemini"
|
|
)
|
|
|
|
func GeminiAuthMiddleware(auth *Auth) sliderule.Middleware {
|
|
if auth == nil {
|
|
return func(inner sliderule.Handler) sliderule.Handler { return inner }
|
|
}
|
|
|
|
return func(inner sliderule.Handler) sliderule.Handler {
|
|
return sliderule.HandlerFunc(func(ctx context.Context, request *sliderule.Request) *sliderule.Response {
|
|
if auth.Strategy.Approve(ctx, request) {
|
|
return inner.Handle(ctx, request)
|
|
}
|
|
|
|
if len(request.TLSState.PeerCertificates) == 0 {
|
|
return gemini.RequireCert("client certificate required")
|
|
}
|
|
return gemini.CertAuthFailure("client certificate rejected")
|
|
})
|
|
}
|
|
}
|
|
|
|
func ClientTLSFile(path string) (AuthStrategy, error) {
|
|
if strings.Contains(path, "~") {
|
|
return UserClientTLSAuth(path), nil
|
|
}
|
|
|
|
f, err := os.Open(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = f.Close() }()
|
|
|
|
contents, err := io.ReadAll(f)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fingerprints := []string{}
|
|
for _, line := range strings.Split(string(contents), "\n") {
|
|
line = strings.Trim(line, " \t\r")
|
|
if len(line) == sha256.Size*2 {
|
|
fingerprints = append(fingerprints, line)
|
|
}
|
|
}
|
|
return ClientTLSAuth(fingerprints), nil
|
|
}
|
|
|
|
func ClientTLS(raw string) AuthStrategy {
|
|
fingerprints := []string{}
|
|
for _, fp := range strings.Split(raw, ",") {
|
|
fp = strings.Trim(fp, " \t\r")
|
|
if len(fp) == sha256.Size*2 {
|
|
fingerprints = append(fingerprints, fp)
|
|
}
|
|
}
|
|
return ClientTLSAuth(fingerprints)
|
|
}
|
|
|
|
type UserClientTLSAuth string
|
|
|
|
func (ca UserClientTLSAuth) Approve(ctx context.Context, request *sliderule.Request) bool {
|
|
u, err := user.Lookup(sliderule.RouteParams(ctx)["username"])
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fpath := resolveTilde(string(ca), u)
|
|
|
|
strat, err := ClientTLSFile(fpath)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return strat.Approve(ctx, request)
|
|
}
|
|
|
|
func resolveTilde(path string, u *user.User) string {
|
|
if strings.HasPrefix(path, "~/") {
|
|
return filepath.Join(u.HomeDir, path[1:])
|
|
}
|
|
return strings.ReplaceAll(path, "~", u.Username)
|
|
}
|
|
|
|
type ClientTLSAuth []string
|
|
|
|
func (ca ClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool {
|
|
if request.TLSState == nil || len(request.TLSState.PeerCertificates) == 0 {
|
|
return false
|
|
}
|
|
return slices.Contains(ca, fingerprint(request.TLSState.PeerCertificates[0].Raw))
|
|
}
|
|
|
|
// fingerprint returns the lowercase hex encoding of the SHA-256 digest of raw.
// Callers in this file pass a certificate's Raw bytes.
func fingerprint(raw []byte) string {
	digest := sha256.New()
	// hash.Hash's Write is documented never to return an error.
	_, _ = digest.Write(raw)
	return hex.EncodeToString(digest.Sum(nil))
}
|
|
|
|
type HasClientTLSAuth struct{}
|
|
|
|
func (_ HasClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool {
|
|
return request.TLSState != nil && len(request.TLSState.PeerCertificates) > 0
|
|
}
|