client-hello-mirror/server.go

257 lines
6.0 KiB
Go

// SPDX-FileCopyrightText: 2022-2023 nervuri <https://nervuri.net/contact>
//
// SPDX-License-Identifier: BSD-3-Clause
package main
import (
"bufio"
"bytes"
"crypto/tls"
_ "embed"
"encoding/binary"
"encoding/json"
"flag"
"io"
"log"
"net"
"os"
"tildegit.org/nervuri/client-hello-mirror/clienthello"
"time"
)
// Per-connection I/O timeouts, in seconds. They are untyped integer constants
// so that they can be multiplied by time.Second directly.
const (
	readTimeout  = 10 // seconds
	writeTimeout = 10 // seconds
)
// tlsConnectionInfo holds parameters of the established TLS connection,
// copied from tls.ConnectionState. requestHandler emits it under the
// "connection_info" key of the /json/v1 response, so the field order and the
// json tags below define that output and should be kept stable.
type tlsConnectionInfo struct {
TlsVersion uint16 `json:"tls_version"` // ConnectionState.Version
CipherSuite uint16 `json:"cipher_suite"` // ConnectionState.CipherSuite
SessionResumed bool `json:"session_resumed"` // ConnectionState.DidResume
NegotiatedProtocol string `json:"alpn_negotiated_protocol"` // ALPN; ConnectionState.NegotiatedProtocol
}
// prefixConn is a net.Conn whose reads come from Reader rather than from the
// embedded Conn. Every other method (Write, Close, deadlines, addresses) still
// goes to Conn. This lets bytes that were already consumed from the socket
// (the Client Hello) be replayed in front of the live stream, so the request
// handler can see the Client Hello while the TLS stack still gets the full input.
// See https://github.com/FiloSottile/mostly-harmless/tree/main/talks/asyncnet
type prefixConn struct {
	net.Conn
	io.Reader
}

// Read reads from the embedded io.Reader. This explicit method is required:
// net.Conn and io.Reader both provide Read at the same embedding depth, so
// without it the selector would be ambiguous and prefixConn would not satisfy
// net.Conn.
func (pc prefixConn) Read(b []byte) (n int, err error) {
	n, err = pc.Reader.Read(b)
	return n, err
}
// fatalError writes its arguments to stderr and terminates the process with
// exit status 1. It behaves like log.Fatal but omits the date/time prefix,
// which suits the startup errors reported before the server loop begins.
func fatalError(err ...any) {
	log.New(os.Stderr, "", 0).Fatal(err...)
}
// html is the contents of index.html, embedded at build time. It is used as
// the response body for HTTP requests to "/".
//go:embed index.html
var html string
// gemtext is the contents of index.gmi, embedded at build time. It is used as
// the response body for Gemini requests to "/".
//go:embed index.gmi
var gemtext string
// peek captures the client's first TLS record (expected to carry the Client
// Hello) directly from conn, before any TLS processing. It then replays those
// bytes into a server-side TLS handshake and passes both the established
// *tls.Conn and the raw record bytes to requestHandler.
//
// conn is always closed before peek returns. If the first record is not a TLS
// handshake record, the connection is dropped without logging; other failures
// are logged.
//
// NOTE(review): only the first record is captured. TLS allows a Client Hello
// to be fragmented across several records, in which case rawClientHello is
// truncated. Supporting that would require knowing the exact byte layout that
// clienthello.ClientHelloMsg.Unmarshal expects (it currently receives the
// 5-byte record header followed by the record payload).
func peek(conn net.Conn, tlsConfig *tls.Config) {
	defer conn.Close()

	// This read deadline is never extended. It therefore bounds the Client
	// Hello read, the handshake, and the request read in requestHandler.
	err := conn.SetReadDeadline(time.Now().Add(readTimeout * time.Second))
	if err != nil {
		log.Println("SetReadDeadline error: ", err.Error())
		return
	}
	// Also bound the handshake's own writes (ServerHello, certificates, ...),
	// which otherwise have no deadline at all. requestHandler sets a fresh
	// write deadline before writing its response.
	err = conn.SetWriteDeadline(time.Now().Add(writeTimeout * time.Second))
	if err != nil {
		log.Println("SetWriteDeadline error: ", err.Error())
		return
	}

	var buf bytes.Buffer
	// TLS record header: content type (1 byte), legacy version (2 bytes),
	// payload length (2 bytes, big-endian).
	if _, err = io.CopyN(&buf, conn, 5); err != nil {
		log.Println(err)
		return
	}
	header := buf.Bytes()
	// Content type 0x16 is "handshake". Anything else is not a TLS handshake,
	// so drop it quietly.
	if header[0] != 0x16 {
		return
	}
	recordLength := binary.BigEndian.Uint16(header[3:5])
	if recordLength == 0 {
		log.Println("Zero-length handshake message")
		return
	}
	// Copy the record payload: the handshake message, or its first fragment.
	if _, err = io.CopyN(&buf, conn, int64(recordLength)); err != nil {
		log.Println(err)
		return
	}
	rawClientHello := buf.Bytes()
	// The first payload byte is the HandshakeType (1 = client_hello). Because
	// recordLength > 0, index 5 is guaranteed to exist.
	if rawClientHello[5] != 1 {
		log.Println("HandshakeType is not client_hello")
		return
	}

	// "Put back" the consumed bytes: the TLS stack first reads them, then the
	// rest of the stream from conn. They are replayed through a separate
	// read-only bytes.Reader rather than by draining buf, so rawClientHello is
	// never tied to a buffer that is still being consumed.
	pConn := prefixConn{
		Conn:   conn,
		Reader: io.MultiReader(bytes.NewReader(rawClientHello), conn),
	}
	tlsConnection := tls.Server(pConn, tlsConfig)
	if err = tlsConnection.Handshake(); err != nil {
		log.Println(err)
		return
	}
	requestHandler(tlsConnection, rawClientHello)
}
// requestHandler serves a single request on an established TLS connection,
// then closes the connection.
//
// It reads only the first line of the request and resolves it with
// getRequestInfo. Unless an error occurs first, it writes exactly one response:
//   - "/": the embedded landing page (html for HTTP, gemtext for Gemini);
//   - "/json/v1": the parsed Client Hello plus the negotiated connection
//     parameters, as indented JSON;
//   - anything else: HTTP 404 or Gemini status 51.
//
// rawClientHello holds the bytes captured by peek before the handshake. It is
// parsed only for "/json/v1", so the other pages are still served when the
// Client Hello cannot be parsed.
func requestHandler(conn *tls.Conn, rawClientHello []byte) {
	defer conn.Close()

	// Only the first line is needed to route the request.
	scanner := bufio.NewScanner(conn)
	scanner.Scan()
	if err := scanner.Err(); err != nil {
		log.Println(err)
		return
	}
	req, err := getRequestInfo(scanner.Text())
	if err != nil {
		log.Println(err)
		return
	}

	resp := NewResponse(&req)
	switch req.Path {
	case "/":
		if req.Protocol == httpProtocol {
			resp.DisableCaching()
			resp.Body = html
		} else if req.Protocol == geminiProtocol {
			resp.Body = gemtext
		}
	case "/json/v1":
		var clientHelloMsg clienthello.ClientHelloMsg
		if !clientHelloMsg.Unmarshal(rawClientHello) {
			log.Println("Failed to parse Client Hello")
			return
		}
		if req.Protocol == httpProtocol {
			resp.Headers["Content-Type"] = "application/json"
			resp.DisableCaching()
		} else if req.Protocol == geminiProtocol {
			resp.StatusLine = "20 application/json"
		}
		connectionState := conn.ConnectionState()
		output := struct {
			ClientHello       clienthello.ClientHelloMsg `json:"client_hello"`
			TlsConnectionInfo tlsConnectionInfo          `json:"connection_info"`
		}{
			ClientHello: clientHelloMsg,
			TlsConnectionInfo: tlsConnectionInfo{
				TlsVersion:         connectionState.Version,
				CipherSuite:        connectionState.CipherSuite,
				SessionResumed:     connectionState.DidResume,
				NegotiatedProtocol: connectionState.NegotiatedProtocol,
			},
		}
		outputJSON, err := json.MarshalIndent(output, "", " ")
		if err != nil {
			log.Println(err)
			return
		}
		resp.Body = string(outputJSON)
	default:
		if req.Protocol == httpProtocol {
			resp.StatusLine = "HTTP/1.1 404 Not Found"
			resp.Headers["Content-Type"] = "text/plain; charset=utf-8"
			resp.Body = "404 Not Found"
		} else if req.Protocol == geminiProtocol {
			resp.StatusLine = "51 Not Found!"
		}
	}

	// Give the response write its own deadline, starting now.
	if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout * time.Second)); err != nil {
		log.Println("SetWriteDeadline error: ", err.Error())
		return
	}
	if _, err := conn.Write(resp.Prepare()); err != nil {
		log.Println(err)
	}
}
// main parses the command line, loads the TLS key pair, listens on host:port,
// calls dropPrivileges, and then hands each accepted connection to peek in
// its own goroutine.
//
// Usage: client-hello-mirror -c cert.pem -k key.pem [-u user] host:port
func main() {
	var certFile, keyFile, userToSwitchTo string
	flag.StringVar(&certFile, "c", "", "path to certificate file")
	flag.StringVar(&keyFile, "k", "", "path to private key file")
	flag.StringVar(&userToSwitchTo, "u", "", "user to switch to, if running as root")
	flag.Parse()
	hostAndPort := flag.Arg(0)
	if certFile == "" || keyFile == "" || hostAndPort == "" {
		fatalError("usage: client-hello-mirror -c cert.pem -k key.pem [-u user] host:port")
	}

	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		fatalError(err)
	}
	tlsConfig := tls.Config{
		Certificates: []tls.Certificate{cert},
		// NOTE(review): TLS 1.0 is presumably allowed on purpose so that
		// older clients can still be mirrored; confirm before raising it.
		MinVersion: tls.VersionTLS10,
		NextProtos: []string{"http/1.1"},
	}

	ln, err := net.Listen("tcp", hostAndPort)
	if err != nil {
		fatalError(err)
	}
	defer ln.Close()

	// The listener is opened before dropping privileges, presumably so that
	// a privileged port can be bound while still running as root.
	dropPrivileges(userToSwitchTo)
	log.Println("Server started")

	// After an Accept error (e.g. the process ran out of file descriptors),
	// back off instead of retrying in a tight loop that burns CPU and floods
	// the log. The schedule mirrors net/http's Server.Serve: start at 5ms,
	// double on each consecutive failure, cap at 1s, and reset on success.
	const maxAcceptDelay = time.Second
	var acceptDelay time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if acceptDelay == 0 {
				acceptDelay = 5 * time.Millisecond
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			log.Println("Error accepting: ", err.Error())
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0
		go peek(conn, &tlsConfig)
	}
}