From 6b1c9eafa97cb02dd902ddbf2676c709c86f9625 Mon Sep 17 00:00:00 2001 From: Boris Rybalkin Date: Wed, 1 Mar 2023 21:57:30 +0000 Subject: [PATCH] unix socket support (#2974) ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Boris Rybalkin ` I need this for Syncloud project (https://github.com/syncloud/platform) where I run multiple apps behind an nginx on the same RPi like device so unix socket is very convenient to not have port conflicts between apps. Also someone opened this Issue: https://github.com/matrix-org/dendrite/issues/2924 --------- Co-authored-by: kegsay Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> --- cmd/dendrite/main.go | 28 +++++++++++++++-- setup/base/base.go | 38 ++++++++++++++++------ setup/base/base_test.go | 49 +++++++++++++++++++++++++++-- setup/config/config.go | 15 --------- setup/config/config_address.go | 45 ++++++++++++++++++++++++++ setup/config/config_address_test.go | 25 +++++++++++++++ 6 files changed, 171 insertions(+), 29 deletions(-) create mode 100644 setup/config/config_address.go create mode 100644 setup/config/config_address_test.go diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go index e8ff0a478..1ae348cfa 100644 --- a/cmd/dendrite/main.go +++ b/cmd/dendrite/main.go @@ -16,6 +16,7 @@ package main import ( "flag" + "io/fs" "github.com/sirupsen/logrus" @@ -30,6 +31,12 @@ import ( ) var ( + unixSocket = flag.String("unix-socket", "", + "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", + ) + unixSocketPermission = flag.Int("unix-socket-permission", 0755, + "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server", + ) httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") @@ -38,8 +45,23 @@ var ( func main() { cfg := setup.ParseFlags(true) - httpAddr := config.HTTPAddress("http://" + *httpBindAddr) - httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) + httpAddr := config.ServerAddress{} + httpsAddr := config.ServerAddress{} + if *unixSocket == "" { + http, err := config.HTTPAddress("http://" + *httpBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse http address") + } + httpAddr = http + https, err := config.HTTPAddress("https://" + *httpsBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse https address") + } + httpsAddr = https + } else { + httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission)) + } + options := []basepkg.BaseDendriteOptions{} base := basepkg.NewBaseDendrite(cfg, options...) @@ -92,7 +114,7 @@ func main() { base.SetupAndServeHTTP(httpAddr, nil, nil) }() // Handle HTTPS if certificate and key are provided - if *certFile != "" && *keyFile != "" { + if *unixSocket == "" && *certFile != "" && *keyFile != "" { go func() { base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) }() diff --git a/setup/base/base.go b/setup/base/base.go index aabdd7937..dfe48ff3c 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -20,9 +20,11 @@ import ( "database/sql" "embed" "encoding/json" + "errors" "fmt" "html/template" "io" + "io/fs" "net" "net/http" _ "net/http/pprof" @@ -85,8 +87,6 @@ type BaseDendrite struct { startupLock sync.Mutex } -const NoListener = "" - const HTTPServerTimeout = time.Minute * 5 type BaseDendriteOptions int @@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() { // SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs // and adds a prometheus handler under /_dendrite/metrics. func (b *BaseDendrite) SetupAndServeHTTP( - externalHTTPAddr config.HTTPAddress, + externalHTTPAddr config.ServerAddress, certFile, keyFile *string, ) { // Manually unlocked right before actually serving requests, // as we don't return from this method (defer doesn't work). b.startupLock.Lock() - externalAddr, _ := externalHTTPAddr.Address() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() externalServ := &http.Server{ - Addr: string(externalAddr), + Addr: externalHTTPAddr.Address, WriteTimeout: HTTPServerTimeout, Handler: externalRouter, BaseContext: func(_ net.Listener) context.Context { @@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( b.startupLock.Unlock() - if externalAddr != NoListener { + if externalHTTPAddr.Enabled() { go func() { var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once logrus.Infof("Starting external listener on %s", externalServ.Addr) @@ -437,9 +436,30 @@ func (b *BaseDendrite) SetupAndServeHTTP( } } } else { - if err := externalServ.ListenAndServe(); err != nil { - if err != http.ErrServerClosed { - logrus.WithError(err).Fatal("failed to serve HTTP") + if externalHTTPAddr.IsUnixSocket() { + err := os.Remove(externalHTTPAddr.Address) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + logrus.WithError(err).Fatal("failed to remove existing unix socket") + } + listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address) + if err != nil { + logrus.WithError(err).Fatal("failed to serve unix socket") + } + err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission) + if err != nil { + logrus.WithError(err).Fatal("failed to set unix socket permissions") + } + if err := externalServ.Serve(listener); err != nil { + if err != http.ErrServerClosed { + logrus.WithError(err).Fatal("failed to serve unix socket") + } + } + + } else { + if err := externalServ.ListenAndServe(); err != nil { + if err != http.ErrServerClosed { + logrus.WithError(err).Fatal("failed to serve HTTP") + } } } } diff --git a/setup/base/base_test.go b/setup/base/base_test.go index d906294c0..658dc5b03 100644 --- a/setup/base/base_test.go +++ b/setup/base/base_test.go @@ -2,10 +2,13 @@ package base_test import ( "bytes" + "context" "embed" "html/template" + "net" "net/http" "net/http/httptest" + "path" "testing" "time" @@ -18,7 +21,7 @@ import ( //go:embed static/*.gotmpl var staticContent embed.FS -func TestLandingPage(t *testing.T) { +func TestLandingPage_Tcp(t *testing.T) { // generate the expected result tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) expectedRes := &bytes.Buffer{} @@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) { s.Close() // start base with the listener and wait for it to be started - go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil) + address, err := config.HTTPAddress(s.URL) + assert.NoError(t, err) + go b.SetupAndServeHTTP(address, nil, nil) time.Sleep(time.Millisecond * 10) // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page @@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) { // Using .String() for user friendly output assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") } + +func TestLandingPage_UnixSocket(t *testing.T) { + // generate the expected result + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + expectedRes := &bytes.Buffer{} + err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }) + assert.NoError(t, err) + + b, _, _ := testrig.Base(nil) + defer b.Close() + + tempDir := t.TempDir() + socket := path.Join(tempDir, "socket") + // start base with the listener and wait for it to be started + address := config.UnixSocketAddress(socket, 0755) + assert.NoError(t, err) + go b.SetupAndServeHTTP(address, nil, nil) + time.Sleep(time.Millisecond * 100) + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socket) + }, + }, + } + resp, err := client.Get("http://unix/") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // read the response + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(resp.Body) + assert.NoError(t, err) + + // Using .String() for user friendly output + assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") +} diff --git a/setup/config/config.go b/setup/config/config.go index 848766162..1a25f71eb 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -19,7 +19,6 @@ import ( "encoding/pem" "fmt" "io" - "net/url" "os" "path/filepath" "regexp" @@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool { // A Topic in kafka. type Topic string -// An Address to listen on. -type Address string - -// An HTTPAddress to listen on, starting with either http:// or https://. -type HTTPAddress string - -func (h HTTPAddress) Address() (Address, error) { - url, err := url.Parse(string(h)) - if err != nil { - return "", err - } - return Address(url.Host), nil -} - // FileSizeBytes is a file size in bytes type FileSizeBytes int64 diff --git a/setup/config/config_address.go b/setup/config/config_address.go new file mode 100644 index 000000000..0e4f0296f --- /dev/null +++ b/setup/config/config_address.go @@ -0,0 +1,45 @@ +package config + +import ( + "io/fs" + "net/url" +) + +const ( + NetworkTCP = "tcp" + NetworkUnix = "unix" +) + +type ServerAddress struct { + Address string + Scheme string + UnixSocketPermission fs.FileMode +} + +func (s ServerAddress) Enabled() bool { + return s.Address != "" +} + +func (s ServerAddress) IsUnixSocket() bool { + return s.Scheme == NetworkUnix +} + +func (s ServerAddress) Network() string { + if s.Scheme == NetworkUnix { + return NetworkUnix + } else { + return NetworkTCP + } +} + +func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress { + return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm} +} + +func HTTPAddress(urlAddress string) (ServerAddress, error) { + parsedUrl, err := url.Parse(urlAddress) + if err != nil { + return ServerAddress{}, err + } + return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil +} diff --git a/setup/config/config_address_test.go b/setup/config/config_address_test.go new file mode 100644 index 000000000..1be484fd5 --- /dev/null +++ b/setup/config/config_address_test.go @@ -0,0 +1,25 @@ +package config + +import ( + "io/fs" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHttpAddress_ParseGood(t *testing.T) { + address, err := HTTPAddress("http://localhost:123") + assert.NoError(t, err) + assert.Equal(t, "localhost:123", address.Address) + assert.Equal(t, "tcp", address.Network()) +} + +func TestHttpAddress_ParseBad(t *testing.T) { + _, err := HTTPAddress(":") + assert.Error(t, err) +} + +func TestUnixSocketAddress_Network(t *testing.T) { + address := UnixSocketAddress("/tmp", fs.FileMode(0755)) + assert.Equal(t, "unix", address.Network()) +}