matrix-org.dendrite/federationapi/federationapi_keys_test.go

240 lines
7.3 KiB
Go

package federationapi
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config"
)
// server bundles everything one simulated homeserver needs in these tests:
// its name, the key validity it advertises, and the federation components
// that TestMain builds for it.
type server struct {
	name      spec.ServerName           // the server's name, e.g. "a.com"
	validity  time.Duration             // how long from now its key stays valid
	config    *config.FederationAPI     // minimal config, filled in by TestMain
	fedclient fclient.FederationClient  // routed through MockRoundTripper
	cache     *caching.Caches           // cache private to this server
	api       api.FederationInternalAPI // key API private to this server
}

// renew moves the server's key validity to one hour from now. It updates
// both the server record and the KeyValidityPeriod in the server's config.
// Servers A and C start with keys that expire now or have already expired,
// so calling this simulates the remote server renewing its key.
func (s *server) renew() {
	const renewedValidity = time.Hour
	s.validity = renewedValidity
	s.config.Matrix.KeyValidityPeriod = renewedValidity
}
var (
	// serverKeyID is the key ID that every test server signs with.
	serverKeyID = gomatrixserverlib.KeyID("ed25519:auto")

	// The three test servers differ only in when their keys expire.
	serverA = &server{name: "a.com", validity: 0}          // expires now
	serverB = &server{name: "b.com", validity: time.Hour}  // expires in an hour
	serverC = &server{name: "c.com", validity: -time.Hour} // expired an hour ago
)

// servers indexes the test servers by name. MockRoundTripper uses it to
// route a request to the server named in the request's Host. The keys are
// taken from each server's own name, so the index cannot drift from the
// server definitions above.
var servers = func() map[string]*server {
	byName := make(map[string]*server, 3)
	for _, s := range []*server{serverA, serverB, serverC} {
		byName[string(s.name)] = s
	}
	return byName
}()
// TestMain builds an independent federation API for every entry in
// servers, then runs the package's tests against them. Each API gets its
// own signing key, cache, JetStream directory, federation client and
// database file.
//
// The setup runs inside a closure so that its deferred cleanups (closing
// each database file and removing each temporary JetStream directory)
// execute after m.Run returns but before os.Exit. Calling os.Exit directly
// would skip them. The defers are registered inside the loop on purpose:
// every server must stay usable for the whole test run, not just one
// iteration.
func TestMain(m *testing.M) {
	os.Exit(func() int {
		for _, s := range servers {
			// Generate a new signing key for this server.
			_, testPriv, err := ed25519.GenerateKey(nil)
			if err != nil {
				panic("can't generate identity key: " + err.Error())
			}

			// Create a new cache but don't enable prometheus!
			s.cache = caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
			natsInstance := jetstream.NATSInstance{}

			// Create a temporary directory for JetStream.
			d, err := os.MkdirTemp("./", "jetstream*")
			if err != nil {
				panic(err)
			}
			defer os.RemoveAll(d)

			// Draw up just enough Dendrite config for the server key
			// API to work.
			cfg := &config.Dendrite{}
			cfg.Defaults(config.DefaultOpts{
				Generate:       true,
				SingleDatabase: false,
			})
			cfg.Global.ServerName = s.name
			cfg.Global.PrivateKey = testPriv
			cfg.Global.JetStream.InMemory = true
			// Each server gets a distinct topic prefix (the first letter
			// of its name), presumably so the servers' JetStream topics
			// don't collide. TODO: confirm.
			cfg.Global.JetStream.TopicPrefix = string(s.name[:1])
			cfg.Global.JetStream.StoragePath = config.Path(d)
			cfg.Global.KeyID = serverKeyID
			cfg.Global.KeyValidityPeriod = s.validity
			cfg.FederationAPI.KeyPerspectives = nil
			f, err := os.CreateTemp(d, "federation_keys_test*.db")
			if err != nil {
				// Say why setup failed instead of exiting with no
				// explanation.
				fmt.Fprintf(os.Stderr, "can't create database file for %s: %s\n", s.name, err)
				return 1
			}
			defer f.Close()
			cfg.FederationAPI.Database.ConnectionString = config.DataSource("file:" + f.Name())
			s.config = &cfg.FederationAPI

			// Create a transport which redirects federation requests to
			// the mock round tripper. Since we're not *really* listening for
			// federation requests then this will return the key instead.
			transport := &http.Transport{}
			transport.RegisterProtocol("matrix", &MockRoundTripper{})

			// Create the federation client.
			s.fedclient = fclient.NewFederationClient(
				s.config.Matrix.SigningIdentities(),
				fclient.WithTransport(transport),
			)

			// Finally, build the server key APIs.
			processCtx := process.NewProcessContext()
			cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
			s.api = NewInternalAPI(processCtx, cfg, cm, &natsInstance, s.fedclient, nil, s.cache, nil, true)
		}

		// Now that we have built our server key APIs, start the
		// rest of the tests.
		return m.Run()
	}())
}
// MockRoundTripper is an http.RoundTripper for the "matrix" scheme. It
// answers /_matrix/key/v2/server requests locally, using the configuration
// of the matching test server in servers, so no network is involved.
type MockRoundTripper struct{}

// RoundTrip serves req from the in-memory test servers.
//
// The request's Host selects which server's keys are returned. An unknown
// host, or any path other than /_matrix/key/v2/server, is returned as an
// error, since either one means the test itself is broken. The status code
// from routing.LocalKeys is passed through unchanged, so a failed local key
// lookup reaches the federation client as an error response, not a 200.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Check if the request is looking for keys from a server that
	// we know about in the test.
	s, ok := servers[req.Host]
	if !ok {
		return nil, fmt.Errorf("server not known: %s", req.Host)
	}

	// We're intercepting /matrix/key/v2/server requests here, so check
	// that the URL supplied in the request is for that.
	if req.URL.Path != "/_matrix/key/v2/server" {
		return nil, fmt.Errorf("unexpected request path: %s", req.URL.Path)
	}

	// Get the keys and JSON-ify them.
	keys := routing.LocalKeys(s.config, spec.ServerName(req.Host))
	body, err := json.MarshalIndent(keys.JSON, "", " ")
	if err != nil {
		return nil, err
	}

	// Respond with whatever status LocalKeys chose, plus enough response
	// metadata to look like a real JSON reply.
	return &http.Response{
		StatusCode:    keys.Code,
		Status:        fmt.Sprintf("%d %s", keys.Code, http.StatusText(keys.Code)),
		Header:        http.Header{"Content-Type": []string{"application/json"}},
		Body:          io.NopCloser(bytes.NewReader(body)),
		ContentLength: int64(len(body)),
		Request:       req,
	}, nil
}
// TestServersRequestOwnKeys checks that every test server can fetch its own
// signing key through its own key API. Each server runs as a subtest named
// after it, so a failure identifies the server responsible.
func TestServersRequestOwnKeys(t *testing.T) {
	for name, s := range servers {
		// t.Run without t.Parallel runs synchronously, so capturing the
		// loop variables here is safe on every Go version.
		t.Run(name, func(t *testing.T) {
			req := gomatrixserverlib.PublicKeyLookupRequest{
				ServerName: s.name,
				KeyID:      serverKeyID,
			}
			res, err := s.api.FetchKeys(
				context.Background(),
				map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
					req: spec.AsTimestamp(time.Now()),
				},
			)
			if err != nil {
				t.Fatalf("server %s could not fetch own key: %s", name, err)
			}
			key, ok := res[req]
			if !ok {
				t.Fatalf("server %s didn't return its own key in the results", name)
			}
			t.Logf("%s's key expires at %s", name, key.ValidUntilTS.Time())
		})
	}
}
// TestRenewalBehaviour checks two things:
//   - An already-expired key is still returned, because it can still be used
//     to verify past events.
//   - After the remote server renews its key, fetching again returns the
//     renewed key.
//
// Server C starts with a key that expired an hour ago. Server A fetches it,
// server C is renewed to an hour of validity, and server A fetches again.
// The second validity must be strictly later than the first.
func TestRenewalBehaviour(t *testing.T) {
	req := gomatrixserverlib.PublicKeyLookupRequest{
		ServerName: serverC.name,
		KeyID:      serverKeyID,
	}

	// fetch asks server A for server C's key and returns the key's
	// ValidUntilTS. The test fails unless exactly that one key comes back.
	fetch := func() spec.Timestamp {
		t.Helper()
		res, err := serverA.api.FetchKeys(
			context.Background(),
			map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
				req: spec.AsTimestamp(time.Now()),
			},
		)
		if err != nil {
			t.Fatalf("server A failed to retrieve server C key: %s", err)
		}
		if len(res) != 1 {
			t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
		}
		key, ok := res[req]
		if !ok {
			t.Fatal("server C isn't included in the key fetch response")
		}
		return key.ValidUntilTS
	}

	originalValidity := fetch()

	// We're now going to kick server C into renewing its key. The key we
	// already have is from the past, so repeating the fetch should make
	// server A renew it. If so, the new key ends up in server A's cache.
	serverC.renew()

	currentValidity := fetch()
	if currentValidity <= originalValidity {
		t.Fatalf("server C key should have renewed but didn't: validity %s, previously %s",
			currentValidity.Time(), originalValidity.Time())
	}
}