Initial Store & Forward Implementation (#2917)

This adds store & forward relays into dendrite for p2p.
A few things have changed:
- new relay api serves new http endpoints for s&f federation
- updated outbound federation queueing which will attempt to forward
using s&f if appropriate
- database entries to track s&f relays for other nodes
This commit is contained in:
devonh 2023-01-23 17:55:12 +00:00 committed by GitHub
parent 48fa869fa3
commit 5b73592f5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
77 changed files with 7646 additions and 1373 deletions

1
.gitignore vendored
View File

@ -56,6 +56,7 @@ dendrite.yaml
# Database files
*.db
*.db-journal
# Log files
*.log*

View File

@ -41,13 +41,16 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi"
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
userapiAPI "github.com/matrix-org/dendrite/userapi/api"
@ -71,10 +74,12 @@ const (
PeerTypeMulticast = pineconeRouter.PeerTypeMulticast
PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth
PeerTypeBonjour = pineconeRouter.PeerTypeBonjour
relayServerRetryInterval = time.Second * 30
)
type DendriteMonolith struct {
logger logrus.Logger
baseDendrite *base.BaseDendrite
PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions
@ -83,8 +88,9 @@ type DendriteMonolith struct {
CacheDirectory string
listener net.Listener
httpServer *http.Server
processContext *process.ProcessContext
userAPI userapiAPI.UserInternalAPI
federationAPI api.FederationInternalAPI
relayServersQueried map[gomatrixserverlib.ServerName]bool
}
func (m *DendriteMonolith) PublicKey() string {
@ -326,6 +332,7 @@ func (m *DendriteMonolith) Start() {
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media"))
cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media"))
cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(m.StorageDirectory, prefix)))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
@ -335,9 +342,9 @@ func (m *DendriteMonolith) Start() {
panic(err)
}
base := base.NewBaseDendrite(cfg, "Monolith")
base := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics)
m.baseDendrite = base
base.ConfigureAdminEndpoints()
defer base.Close() // nolint: errcheck
federation := conn.CreateFederationClient(base, m.PineconeQUIC)
@ -346,11 +353,11 @@ func (m *DendriteMonolith) Start() {
rsAPI := roomserver.NewInternalAPI(base)
fsAPI := federationapi.NewInternalAPI(
m.federationAPI = federationapi.NewInternalAPI(
base, federation, rsAPI, base.Caches, keyRing, true,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, m.federationAPI, rsAPI)
m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(m.userAPI)
@ -358,10 +365,24 @@ func (m *DendriteMonolith) Start() {
// The underlying roomserver implementation needs to be able to call the fedsender.
// This is different to rsAPI which can be the http client which doesn't need this dependency
rsAPI.SetFederationAPI(fsAPI, keyRing)
rsAPI.SetFederationAPI(m.federationAPI, keyRing)
userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, m.federationAPI, federation)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: m.userAPI,
}
relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer)
monolith := setup.Monolith{
Config: base.Cfg,
@ -370,10 +391,11 @@ func (m *DendriteMonolith) Start() {
KeyRing: keyRing,
AppserviceAPI: asAPI,
FederationAPI: fsAPI,
FederationAPI: m.federationAPI,
RoomserverAPI: rsAPI,
UserAPI: m.userAPI,
KeyAPI: keyAPI,
RelayAPI: relayAPI,
ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
}
@ -411,8 +433,6 @@ func (m *DendriteMonolith) Start() {
Handler: h2c.NewHandler(pMux, h2s),
}
m.processContext = base.ProcessContext
go func() {
m.logger.Info("Listening on ", cfg.Global.ServerName)
@ -420,7 +440,7 @@ func (m *DendriteMonolith) Start() {
case net.ErrClosed, http.ErrServerClosed:
m.logger.Info("Stopped listening on ", cfg.Global.ServerName)
default:
m.logger.Fatal(err)
m.logger.Error("Stopped listening on ", cfg.Global.ServerName)
}
}()
go func() {
@ -430,33 +450,44 @@ func (m *DendriteMonolith) Start() {
case net.ErrClosed, http.ErrServerClosed:
m.logger.Info("Stopped listening on ", cfg.Global.ServerName)
default:
m.logger.Fatal(err)
m.logger.Error("Stopped listening on ", cfg.Global.ServerName)
}
}()
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
stopRelayServerSync := make(chan bool)
relayRetriever := RelayServerRetriever{
Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()),
FederationAPI: m.federationAPI,
relayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
RelayAPI: monolith.RelayAPI,
running: *atomic.NewBool(false),
}
relayRetriever.InitializeRelayServers(eLog)
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
if !relayRetriever.running.Load() {
go relayRetriever.SyncRelayServers(stopRelayServerSync)
}
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 {
stopRelayServerSync <- true
}
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
// eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
if err := m.federationAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
@ -464,12 +495,106 @@ func (m *DendriteMonolith) Start() {
}
func (m *DendriteMonolith) Stop() {
m.processContext.ShutdownDendrite()
m.baseDendrite.Close()
m.baseDendrite.WaitForShutdown()
_ = m.listener.Close()
m.PineconeMulticast.Stop()
_ = m.PineconeQUIC.Close()
_ = m.PineconeRouter.Close()
m.processContext.WaitForComponentsToFinish()
}
type RelayServerRetriever struct {
Context context.Context
ServerName gomatrixserverlib.ServerName
FederationAPI api.FederationInternalAPI
RelayAPI relayServerAPI.RelayInternalAPI
relayServersQueried map[gomatrixserverlib.ServerName]bool
queriedServersMutex sync.Mutex
running atomic.Bool
}
func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)}
response := api.P2PQueryRelayServersResponse{}
err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response)
if err != nil {
eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error())
}
for _, server := range response.RelayServers {
m.relayServersQueried[server] = false
}
eLog.Infof("Registered relay servers: %v", response.RelayServers)
}
func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) {
defer m.running.Store(false)
t := time.NewTimer(relayServerRetryInterval)
for {
relayServersToQuery := []gomatrixserverlib.ServerName{}
func() {
m.queriedServersMutex.Lock()
defer m.queriedServersMutex.Unlock()
for server, complete := range m.relayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
}
}
}()
if len(relayServersToQuery) == 0 {
// All relay servers have been synced.
return
}
m.queryRelayServers(relayServersToQuery)
t.Reset(relayServerRetryInterval)
select {
case <-stop:
if !t.Stop() {
<-t.C
}
return
case <-t.C:
}
}
}
func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool {
m.queriedServersMutex.Lock()
defer m.queriedServersMutex.Unlock()
result := map[gomatrixserverlib.ServerName]bool{}
for server, queried := range m.relayServersQueried {
result[server] = queried
}
return result
}
func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
logrus.Info("querying relay servers for any available transactions")
for _, server := range relayServers {
userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false)
if err != nil {
return
}
err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server)
if err == nil {
func() {
m.queriedServersMutex.Lock()
defer m.queriedServersMutex.Unlock()
m.relayServersQueried[server] = true
}()
// TODO : What happens if your relay receives new messages after this point?
// Should you continue to check with them, or should they try and contact you?
// They could send a "new_async_events" message your way maybe?
// Then you could mark them as needing to be queried again.
// What if you miss this message?
// Maybe you should try querying them again after a certain period of time as a backup?
} else {
logrus.Errorf("Failed querying relay server: %s", err.Error())
}
}
}
const MaxFrameSize = types.MaxFrameSize

View File

@ -0,0 +1,198 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gobind
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/api"
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"gotest.tools/v3/poll"
)
var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
type TestNetConn struct {
net.Conn
shouldFail bool
}
func (t *TestNetConn) Read(b []byte) (int, error) {
if t.shouldFail {
return 0, fmt.Errorf("Failed")
} else {
n := copy(b, TestBuf)
return n, nil
}
}
func (t *TestNetConn) Write(b []byte) (int, error) {
if t.shouldFail {
return 0, fmt.Errorf("Failed")
} else {
return len(b), nil
}
}
func (t *TestNetConn) Close() error {
if t.shouldFail {
return fmt.Errorf("Failed")
} else {
return nil
}
}
func TestConduitStoresPort(t *testing.T) {
conduit := Conduit{port: 7}
assert.Equal(t, 7, conduit.Port())
}
func TestConduitRead(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
b := make([]byte, len(TestBuf))
bytes, err := conduit.Read(b)
assert.NoError(t, err)
assert.Equal(t, len(TestBuf), bytes)
assert.Equal(t, TestBuf, b)
}
func TestConduitReadCopy(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
result, err := conduit.ReadCopy()
assert.NoError(t, err)
assert.Equal(t, TestBuf, result)
}
func TestConduitWrite(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
bytes, err := conduit.Write(TestBuf)
assert.NoError(t, err)
assert.Equal(t, len(TestBuf), bytes)
}
func TestConduitClose(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
err := conduit.Close()
assert.NoError(t, err)
assert.True(t, conduit.closed.Load())
}
func TestConduitReadClosed(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
err := conduit.Close()
assert.NoError(t, err)
b := make([]byte, len(TestBuf))
_, err = conduit.Read(b)
assert.Error(t, err)
}
func TestConduitReadCopyClosed(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
err := conduit.Close()
assert.NoError(t, err)
_, err = conduit.ReadCopy()
assert.Error(t, err)
}
func TestConduitWriteClosed(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{}}
err := conduit.Close()
assert.NoError(t, err)
_, err = conduit.Write(TestBuf)
assert.Error(t, err)
}
func TestConduitReadCopyFails(t *testing.T) {
conduit := Conduit{conn: &TestNetConn{shouldFail: true}}
_, err := conduit.ReadCopy()
assert.Error(t, err)
}
var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"}
type FakeFedAPI struct {
api.FederationInternalAPI
}
func (f *FakeFedAPI) P2PQueryRelayServers(ctx context.Context, req *api.P2PQueryRelayServersRequest, res *api.P2PQueryRelayServersResponse) error {
res.RelayServers = testRelayServers
return nil
}
type FakeRelayAPI struct {
relayServerAPI.RelayInternalAPI
}
func (r *FakeRelayAPI) PerformRelayServerSync(ctx context.Context, userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error {
return nil
}
func TestRelayRetrieverInitialization(t *testing.T) {
retriever := RelayServerRetriever{
Context: context.Background(),
ServerName: "server",
relayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
FederationAPI: &FakeFedAPI{},
RelayAPI: &FakeRelayAPI{},
}
retriever.InitializeRelayServers(logrus.WithField("test", "relay"))
relayServers := retriever.GetQueriedServerStatus()
assert.Equal(t, 2, len(relayServers))
}
func TestRelayRetrieverSync(t *testing.T) {
retriever := RelayServerRetriever{
Context: context.Background(),
ServerName: "server",
relayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
FederationAPI: &FakeFedAPI{},
RelayAPI: &FakeRelayAPI{},
}
retriever.InitializeRelayServers(logrus.WithField("test", "relay"))
relayServers := retriever.GetQueriedServerStatus()
assert.Equal(t, 2, len(relayServers))
stopRelayServerSync := make(chan bool)
go retriever.SyncRelayServers(stopRelayServerSync)
check := func(log poll.LogT) poll.Result {
relayServers := retriever.GetQueriedServerStatus()
for _, queried := range relayServers {
if !queried {
return poll.Continue("waiting for all servers to be queried")
}
}
stopRelayServerSync <- true
return poll.Success()
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestMonolithStarts(t *testing.T) {
monolith := DendriteMonolith{}
monolith.Start()
monolith.PublicKey()
monolith.Stop()
}

View File

@ -0,0 +1,59 @@
## Relay Server Architecture
Relay Servers function similar to the way physical mail drop boxes do.
A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the relay server.
Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination.
Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means.
Currently events are sent as complete Matrix Transactions.
Transactions include a list of PDUs, (which contain, among other things, lists of authorization events, previous events, and signatures) a list of EDUs, and other information about the transaction.
There is no additional information sent along with the transaction other than what is typically added to them during Matrix federation today.
In the future this will probably need to change in order to handle more complex room state resolution during p2p usage.
### Relay Server Architecture
```
0 +--------------------+
+----------------------------------------+ | P2P Node A |
| Relay Server | | +--------+ |
| | | | Client | |
| +--------------------+ | | +--------+ |
| | Relay Server API | | | | |
| | | | | V |
| .--------. 2 | +-------------+ | | 1 | +------------+ |
| |`--------`| <----- | Forwarder | <------------- | Homeserver | |
| | Database | | +-------------+ | | | +------------+ |
| `----------` | | | +--------------------+
| ^ | | |
| | 4 | +-------------+ | |
| `------------ | Retriever | <------. +--------------------+
| | +-------------+ | | | | P2P Node B |
| | | | | | +--------+ |
| +--------------------+ | | | | Client | |
| | | | +--------+ |
+----------------------------------------+ | | | |
| | V |
3 | | +------------+ |
`------ | Homeserver | |
| +------------+ |
+--------------------+
```
- 0: This relay server is currently only acting on behalf of `P2P Node B`. It will only receive, and later forward events that are destined for `P2P Node B`.
- 1: When `P2P Node A` fails sending directly to `P2P Node B` (after a configurable number of attempts), it checks for any known relay servers associated with `P2P Node B` and sends to all of them.
- If sending to any of the relay servers succeeds, that transaction is considered to be successfully sent.
- 2: The relay server `forwarder` stores the transaction json in its database and marks it as destined for `P2P Node B`.
- 3: When `P2P Node B` comes online, it queries all its relay servers for any missed messages.
- 4: The relay server `retriever` will look in its database for any transactions that are destined for `P2P Node B` and returns them one at a time.
For now, it is important that we dont design out a hybrid approach of having both sender-side and recipient-side relay servers.
Both approaches make sense and determining which makes for a better experience depends on the use case.
#### Sender-Side Relay Servers
If we are running around truly ad-hoc, and I don't know when or where you will be able to pick up messages, then having a sender designated server makes sense to give things the best chance at making their way to the destination.
But in order to achieve this, you are either relying on p2p presence broadcasts for the relay to know when to try forwarding (which means you are in a pretty small network), or the relay just keeps on periodically attempting to forward to the destination which will lead to a lot of extra traffic on the network.
#### Recipient-Side Relay Servers
If we have agreed to some static relay server before going off and doing other things, or if we are talking about more global p2p federation, then having a recipient designated relay server can cut down on redundant traffic since it will sit there idle until the recipient pulls events from it.

View File

@ -24,3 +24,42 @@ Then point your favourite Matrix client to the homeserver URL`http://localhost:
If your peering connection is operational then you should see a `Connected TCP:` line in the log output. If not then try a different peer.
Once logged in, you should be able to open the room directory or join a room by its ID.
## Store & Forward Relays
To test out the store & forward relay functionality, you need a minimum of 3 instances.
One instance will act as the relay, and the other two instances will be the users trying to communicate.
Then you can send messages between the two nodes and watch as the relay is used if the receiving node is offline.
### Launching the Nodes
Relay Server:
```
go run cmd/dendrite-demo-pinecone/main.go -dir relay/ -listen "[::]:49000"
```
Node 1:
```
go run cmd/dendrite-demo-pinecone/main.go -dir node-1/ -peer "[::]:49000" -port 8007
```
Node 2:
```
go run cmd/dendrite-demo-pinecone/main.go -dir node-2/ -peer "[::]:49000" -port 8009
```
### Database Setup
At the moment, the database must be manually configured.
For both `Node 1` and `Node 2` add the following entries to their respective `relay_server` table in the federationapi database:
```
server_name: {node_1_public_key}, relay_server_name: {relay_public_key}
server_name: {node_2_public_key}, relay_server_name: {relay_public_key}
```
After editing the database you will need to relaunch the nodes for the changes to be picked up by dendrite.
### Testing
Now you can run two separate instances of element and connect them to `Node 1` and `Node 2`.
You can shutdown one of the nodes and continue sending messages. If you wait long enough, the message will be sent to the relay server. (you can see this in the log output of the relay server)

View File

@ -38,16 +38,21 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi"
relayServerAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib"
"go.uber.org/atomic"
pineconeConnections "github.com/matrix-org/pinecone/connections"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
@ -66,6 +71,8 @@ var (
instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)")
)
const relayServerRetryInterval = time.Second * 30
// nolint:gocyclo
func main() {
flag.Parse()
@ -139,6 +146,7 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName)))
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName)))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName)))
cfg.ClientAPI.RegistrationDisabled = false
@ -224,6 +232,20 @@ func main() {
userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation)
js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &base.Cfg.FederationAPI,
UserAPI: userAPI,
}
relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer)
monolith := setup.Monolith{
Config: base.Cfg,
Client: conn.CreateClient(base, pQUIC),
@ -235,6 +257,7 @@ func main() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
RelayAPI: relayAPI,
ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
}
@ -305,27 +328,38 @@ func main() {
go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events")
relayServerSyncRunning := atomic.NewBool(false)
stopRelayServerSync := make(chan bool)
m := RelayServerRetriever{
Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()),
FederationAPI: fsAPI,
RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
RelayAPI: monolith.RelayAPI,
}
m.InitializeRelayServers(eLog)
for event := range ch {
switch e := event.(type) {
case pineconeEvents.PeerAdded:
if !relayServerSyncRunning.Load() {
go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
}
case pineconeEvents.PeerRemoved:
case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate:
case pineconeEvents.TreeRootAnnUpdate:
case pineconeEvents.SnakeEntryAdded:
case pineconeEvents.SnakeEntryRemoved:
if relayServerSyncRunning.Load() && pRouter.TotalPeerCount() == 0 {
stopRelayServerSync <- true
}
case pineconeEvents.BroadcastReceived:
eLog.Info("Broadcast received from: ", e.PeerID)
// eLog.Info("Broadcast received from: ", e.PeerID)
req := &api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)},
}
res := &api.PerformWakeupServersResponse{}
if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil {
logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID)
eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID)
}
case pineconeEvents.BandwidthReport:
default:
}
}
@ -333,3 +367,78 @@ func main() {
base.WaitForShutdown()
}
type RelayServerRetriever struct {
Context context.Context
ServerName gomatrixserverlib.ServerName
FederationAPI api.FederationInternalAPI
RelayServersQueried map[gomatrixserverlib.ServerName]bool
RelayAPI relayServerAPI.RelayInternalAPI
}
func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)}
response := api.P2PQueryRelayServersResponse{}
err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response)
if err != nil {
eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error())
}
for _, server := range response.RelayServers {
m.RelayServersQueried[server] = false
}
eLog.Infof("Registered relay servers: %v", response.RelayServers)
}
func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) {
defer running.Store(false)
t := time.NewTimer(relayServerRetryInterval)
for {
relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.RelayServersQueried {
if !complete {
relayServersToQuery = append(relayServersToQuery, server)
}
}
if len(relayServersToQuery) == 0 {
// All relay servers have been synced.
return
}
m.queryRelayServers(relayServersToQuery)
t.Reset(relayServerRetryInterval)
select {
case <-stop:
// We have been asked to stop syncing, drain the timer and return.
if !t.Stop() {
<-t.C
}
return
case <-t.C:
// The timer has expired. Continue to the next loop iteration.
}
}
}
func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
logrus.Info("querying relay servers for any available transactions")
for _, server := range relayServers {
userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false)
if err != nil {
return
}
err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server)
if err == nil {
m.RelayServersQueried[server] = true
// TODO : What happens if your relay receives new messages after this point?
// Should you continue to check with them, or should they try and contact you?
// They could send a "new_async_events" message your way maybe?
// Then you could mark them as needing to be queried again.
// What if you miss this message?
// Maybe you should try querying them again after a certain period of time as a backup?
} else {
logrus.Errorf("Failed querying relay server: %s", err.Error())
}
}
}

View File

@ -18,6 +18,7 @@ type FederationInternalAPI interface {
gomatrixserverlib.KeyDatabase
ClientFederationAPI
RoomserverFederationAPI
P2PFederationAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
@ -30,7 +31,6 @@ type FederationInternalAPI interface {
request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse,
) error
PerformWakeupServers(
ctx context.Context,
request *PerformWakeupServersRequest,
@ -71,6 +71,15 @@ type RoomserverFederationAPI interface {
LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
type P2PFederationAPI interface {
// Relay Server sync api used in the pinecone demos.
P2PQueryRelayServers(
ctx context.Context,
request *P2PQueryRelayServersRequest,
response *P2PQueryRelayServersResponse,
) error
}
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError
@ -82,6 +91,7 @@ type KeyserverFederationAPI interface {
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
type FederationClient interface {
P2PFederationClient
gomatrixserverlib.KeyClient
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
@ -110,6 +120,11 @@ type FederationClient interface {
LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
type P2PFederationClient interface {
P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error)
P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error)
}
// FederationClientError is returned from FederationClient methods in the event of a problem.
type FederationClientError struct {
Err string
@ -233,3 +248,11 @@ type InputPublicKeysRequest struct {
type InputPublicKeysResponse struct {
}
type P2PQueryRelayServersRequest struct {
Server gomatrixserverlib.ServerName
}
type P2PQueryRelayServersResponse struct {
RelayServers []gomatrixserverlib.ServerName
}

View File

@ -113,7 +113,10 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist()
}
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
stats := statistics.NewStatistics(
federationDB,
cfg.FederationMaxRetries+1,
cfg.P2PFederationRetriesUntilAssumedOffline+1)
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)

View File

@ -109,13 +109,14 @@ func NewFederationInternalAPI(
func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) {
stats := a.statistics.ForServer(s)
until, blacklisted := stats.BackoffInfo()
if blacklisted {
if stats.Blacklisted() {
return stats, &api.FederationClientError{
Blacklisted: true,
}
}
now := time.Now()
until := stats.BackoffInfo()
if until != nil && now.Before(*until) {
return stats, &api.FederationClientError{
RetryAfter: time.Until(*until),
@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted(
RetryAfter: retryAfter,
}
}
stats.Success()
stats.Success(statistics.SendDirect)
return res, nil
}
@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted(
s gomatrixserverlib.ServerName, request func() (interface{}, error),
) (interface{}, error) {
stats := a.statistics.ForServer(s)
if _, blacklisted := stats.BackoffInfo(); blacklisted {
if blacklisted := stats.Blacklisted(); blacklisted {
return stats, &api.FederationClientError{
Err: fmt.Sprintf("server %q is blacklisted", s),
Blacklisted: true,

View File

@ -0,0 +1,202 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"testing"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
FailuresUntilAssumedOffline = 3
FailuresUntilBlacklist = 8
)
func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) {
t.queryKeysCalled = true
if t.shouldFail {
return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure")
}
return gomatrixserverlib.RespQueryKeys{}, nil
}
func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) {
t.claimKeysCalled = true
if t.shouldFail {
return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure")
}
return gomatrixserverlib.RespClaimKeys{}, nil
}
func TestFederationClientQueryKeys(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil)
assert.Nil(t, err)
assert.True(t, fedClient.queryKeysCalled)
}
func TestFederationClientQueryKeysBlacklisted(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
testDB.AddServerToBlacklist("server")
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil)
assert.NotNil(t, err)
assert.False(t, fedClient.queryKeysCalled)
}
func TestFederationClientQueryKeysFailure(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{shouldFail: true}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil)
assert.NotNil(t, err)
assert.True(t, fedClient.queryKeysCalled)
}
func TestFederationClientClaimKeys(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil)
assert.Nil(t, err)
assert.True(t, fedClient.claimKeysCalled)
}
func TestFederationClientClaimKeysBlacklisted(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
testDB.AddServerToBlacklist("server")
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil)
assert.NotNil(t, err)
assert.False(t, fedClient.claimKeysCalled)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/statistics"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
)
@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) (err error) {
if !r.shouldAttemptDirectFederation(request.ServerName) {
return fmt.Errorf("relay servers have no meaningful response for directory lookup.")
}
dir, err := r.federation.LookupRoomAlias(
ctx,
r.cfg.Matrix.ServerName,
@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
}
response.RoomID = dir.RoomID
response.ServerNames = dir.Servers
r.statistics.ForServer(request.ServerName).Success()
r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect)
return nil
}
@ -144,6 +149,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{},
) error {
if !r.shouldAttemptDirectFederation(serverName) {
return fmt.Errorf("relay servers have no meaningful response for join.")
}
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return err
@ -164,7 +173,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.MakeJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd"
@ -219,7 +228,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.SendJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// If the remote server returned an event in the "event" key of
// the send_join request then we should use that instead. It may
@ -407,6 +416,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
) error {
if !r.shouldAttemptDirectFederation(serverName) {
return fmt.Errorf("relay servers have no meaningful response for outbound peek.")
}
// create a unique ID for this peek.
// for now we just use the room ID again. In future, if we ever
// support concurrent peeks to the same room with different filters
@ -446,7 +459,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.Peek: %w", err)
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// Work out if we support the room version that has been supplied in
// the peek response.
@ -516,6 +529,10 @@ func (r *FederationInternalAPI) PerformLeave(
// Try each server that we were provided until we land on one that
// successfully completes the make-leave send-leave dance.
for _, serverName := range request.ServerNames {
if !r.shouldAttemptDirectFederation(serverName) {
continue
}
// Try to perform a make_leave using the information supplied in the
// request.
respMakeLeave, err := r.federation.MakeLeave(
@ -585,7 +602,7 @@ func (r *FederationInternalAPI) PerformLeave(
continue
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
return nil
}
@ -616,6 +633,12 @@ func (r *FederationInternalAPI) PerformInvite(
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
// TODO (devon): This should be allowed via a relay. Currently only transactions
// can be sent to relays. Would need to extend relays to handle invites.
if !r.shouldAttemptDirectFederation(destination) {
return fmt.Errorf("relay servers have no meaningful response for invite.")
}
logrus.WithFields(logrus.Fields{
"event_id": request.Event.EventID(),
"user_id": *request.Event.StateKey(),
@ -682,12 +705,8 @@ func (r *FederationInternalAPI) PerformWakeupServers(
func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) {
for _, srv := range destinations {
// Check the statistics cache for the blacklist status to prevent hitting
// the database unnecessarily.
if r.queues.IsServerBlacklisted(srv) {
_ = r.db.RemoveServerFromBlacklist(srv)
}
r.queues.RetryServer(srv)
wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive()
r.queues.RetryServer(srv, wasBlacklisted)
}
}
@ -719,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
return fmt.Errorf("auth chain response is missing m.room.create event")
}
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
func setDefaultRoomVersionFromJoinEvent(
joinEvent gomatrixserverlib.EventBuilder,
) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version"
@ -802,3 +823,31 @@ func federatedAuthProvider(
return returning, nil
}
}
// P2PQueryRelayServers implements api.FederationInternalAPI
func (r *FederationInternalAPI) P2PQueryRelayServers(
ctx context.Context,
request *api.P2PQueryRelayServersRequest,
response *api.P2PQueryRelayServersResponse,
) error {
logrus.Infof("Getting relay servers for: %s", request.Server)
relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server)
if err != nil {
return err
}
response.RelayServers = relayServers
return nil
}
func (r *FederationInternalAPI) shouldAttemptDirectFederation(
destination gomatrixserverlib.ServerName,
) bool {
var shouldRelay bool
stats := r.statistics.ForServer(destination)
if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
shouldRelay = true
}
return !shouldRelay
}

View File

@ -0,0 +1,190 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"testing"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type testFedClient struct {
api.FederationClient
queryKeysCalled bool
claimKeysCalled bool
shouldFail bool
}
func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) {
return gomatrixserverlib.RespDirectory{}, nil
}
func TestPerformWakeupServers(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
testDB.AddServerToBlacklist(server)
testDB.SetServerAssumedOffline(context.Background(), server)
blacklisted, err := testDB.IsServerBlacklisted(server)
assert.NoError(t, err)
assert.True(t, blacklisted)
offline, err := testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err)
assert.True(t, offline)
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "relay",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{server},
}
res := api.PerformWakeupServersResponse{}
err = fedAPI.PerformWakeupServers(context.Background(), &req, &res)
assert.NoError(t, err)
blacklisted, err = testDB.IsServerBlacklisted(server)
assert.NoError(t, err)
assert.False(t, blacklisted)
offline, err = testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err)
assert.False(t, offline)
}
func TestQueryRelayServers(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"}
err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers)
assert.NoError(t, err)
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "relay",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.P2PQueryRelayServersRequest{
Server: server,
}
res := api.P2PQueryRelayServersResponse{}
err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res)
assert.NoError(t, err)
assert.Equal(t, len(relayServers), len(res.RelayServers))
}
func TestPerformDirectoryLookup(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "relay",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.PerformDirectoryLookupRequest{
RoomAlias: "room",
ServerName: "server",
}
res := api.PerformDirectoryLookupResponse{}
err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res)
assert.NoError(t, err)
}
func TestPerformDirectoryLookupRelaying(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
testDB.SetServerAssumedOffline(context.Background(), server)
testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"})
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: server,
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.PerformDirectoryLookupRequest{
RoomAlias: "room",
ServerName: server,
}
res := api.PerformDirectoryLookupResponse{}
err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res)
assert.Error(t, err)
}

View File

@ -24,6 +24,7 @@ const (
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
@ -510,3 +511,14 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) P2PQueryRelayServers(
ctx context.Context,
request *api.P2PQueryRelayServersRequest,
response *api.P2PQueryRelayServersResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryRelayServers", h.federationAPIURL+FederationAPIQueryRelayServers,
h.httpClient, ctx, request, response,
)
}

View File

@ -29,7 +29,7 @@ import (
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
)
@ -70,7 +70,7 @@ type destinationQueue struct {
// Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) {
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
@ -85,7 +85,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
dbReceipt: dbReceipt,
})
} else {
oq.overflowed.Store(true)
@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
// sendEDU adds the EDU event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) {
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
@ -116,7 +116,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
dbReceipt: dbReceipt,
})
} else {
oq.overflowed.Store(true)
@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotPDUs := map[string]struct{}{}
gotEDUs := map[string]struct{}{}
for _, pdu := range oq.pendingPDUs {
gotPDUs[pdu.receipt.String()] = struct{}{}
gotPDUs[pdu.dbReceipt.String()] = struct{}{}
}
for _, edu := range oq.pendingEDUs {
gotEDUs[edu.receipt.String()] = struct{}{}
gotEDUs[edu.dbReceipt.String()] = struct{}{}
}
overflowed := false
@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() {
// If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens.
terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil {
// We failed to send the transaction. Mark it as a failure.
_, blacklisted := oq.statistics.Failure()
@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() {
return
}
} else {
oq.handleTransactionSuccess(pduCount, eduCount)
oq.handleTransactionSuccess(pduCount, eduCount, sendMethod)
}
}
}
// nextTransaction creates a new transaction from the pending event
// queue and sends it.
// Returns an error if the transaction wasn't sent.
// Returns an error if the transaction wasn't sent. And whether the success
// was to a relay server or not.
func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) error {
) (err error, sendMethod statistics.SendMethod) {
// Create the transaction.
t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
@ -407,7 +408,37 @@ func (oq *destinationQueue) nextTransaction(
// Try to send the transaction to the destination server.
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel()
_, err := oq.client.SendTransaction(ctx, t)
relayServers := oq.statistics.KnownRelayServers()
if oq.statistics.AssumedOffline() && len(relayServers) > 0 {
sendMethod = statistics.SendViaRelay
relaySuccess := false
logrus.Infof("Sending to relay servers: %v", relayServers)
// TODO : how to pass through actual userID here?!?!?!?!
userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false)
if userErr != nil {
return userErr, sendMethod
}
// Attempt sending to each known relay server.
for _, relayServer := range relayServers {
_, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer)
if relayErr != nil {
err = relayErr
} else {
// If sending to one of the relay servers succeeds, consider the send successful.
relaySuccess = true
}
}
// Clear the error if sending to any of the relay servers succeeded.
if relaySuccess {
err = nil
}
} else {
sendMethod = statistics.SendDirect
_, err = oq.client.SendTransaction(ctx, t)
}
switch errResponse := err.(type) {
case nil:
// Clean up the transaction in the database.
@ -427,7 +458,7 @@ func (oq *destinationQueue) nextTransaction(
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return nil
return nil, sendMethod
case gomatrix.HTTPError:
// Report that we failed to send the transaction and we
// will retry again, subject to backoff.
@ -437,13 +468,13 @@ func (oq *destinationQueue) nextTransaction(
// to a 400-ish error
code := errResponse.Code
logrus.Debug("Transaction failed with HTTP", code)
return err
return err, sendMethod
default:
logrus.WithFields(logrus.Fields{
"destination": oq.destination,
logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID)
return err
return err, sendMethod
}
}
@ -453,7 +484,7 @@ func (oq *destinationQueue) nextTransaction(
func (oq *destinationQueue) createTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
@ -474,8 +505,8 @@ func (oq *destinationQueue) createTransaction(
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
var pduReceipts []*receipt.Receipt
var eduReceipts []*receipt.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
@ -487,7 +518,7 @@ func (oq *destinationQueue) createTransaction(
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
pduReceipts = append(pduReceipts, pdu.dbReceipt)
}
// Do the same for pending EDUS in the queue.
@ -497,7 +528,7 @@ func (oq *destinationQueue) createTransaction(
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
eduReceipts = append(eduReceipts, edu.dbReceipt)
}
return t, pduReceipts, eduReceipts
@ -530,10 +561,11 @@ func (oq *destinationQueue) blacklistDestination() {
// handleTransactionSuccess updates the cached event queues as well as the success and
// backoff information for this server.
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.statistics.Success(sendMethod)
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()

View File

@ -30,7 +30,7 @@ import (
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
)
@ -138,12 +138,12 @@ func NewOutgoingQueues(
}
type queuedPDU struct {
receipt *shared.Receipt
dbReceipt *receipt.Receipt
pdu *gomatrixserverlib.HeaderedEvent
}
type queuedEDU struct {
receipt *shared.Receipt
dbReceipt *receipt.Receipt
edu *gomatrixserverlib.EDU
}
@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU(
return nil
}
// IsServerBlacklisted returns whether or not the provided server is currently
// blacklisted.
func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool {
return oqs.statistics.ForServer(srv).Blacklisted()
}
// RetryServer attempts to resend events to the given server if we had given up.
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) {
if oqs.disabled {
return
}
serverStatistics := oqs.statistics.ForServer(srv)
forceWakeup := serverStatistics.Blacklisted()
serverStatistics.RemoveBlacklist()
serverStatistics.ClearBackoff()
if queue := oqs.getQueue(srv); queue != nil {
queue.wakeQueueIfEventsPending(forceWakeup)
queue.wakeQueueIfEventsPending(wasBlacklisted)
}
}

View File

@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
@ -26,13 +25,11 @@ import (
"gotest.tools/v3/poll"
"github.com/matrix-org/gomatrixserverlib"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase
}
} else {
// Fake Database
db := createDatabase()
db := test.NewInMemoryFederationDatabase()
b := struct {
ProcessContext *process.ProcessContext
}{ProcessContext: process.NewProcessContext()}
@ -65,220 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase
}
}
func createDatabase() storage.Database {
return &fakeDatabase{
pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent),
pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU),
associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
}
}
type fakeDatabase struct {
storage.Database
dbMutex sync.Mutex
pendingPDUServers map[gomatrixserverlib.ServerName]struct{}
pendingEDUServers map[gomatrixserverlib.ServerName]struct{}
blacklistedServers map[gomatrixserverlib.ServerName]struct{}
pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent
pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU
associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
}
var nidMutex sync.Mutex
var nid = int64(0)
func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal([]byte(js), &event); err == nil {
nidMutex.Lock()
defer nidMutex.Unlock()
nid++
receipt := shared.NewReceipt(nid)
d.pendingPDUs[&receipt] = &event
return &receipt, nil
}
var edu gomatrixserverlib.EDU
if err := json.Unmarshal([]byte(js), &edu); err == nil {
nidMutex.Lock()
defer nidMutex.Unlock()
nid++
receipt := shared.NewReceipt(nid)
d.pendingEDUs[&receipt] = &edu
return &receipt, nil
}
return nil, errors.New("Failed to determine type of json to store")
}
func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
pduCount := 0
pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent)
if receipts, ok := d.associatedPDUs[serverName]; ok {
for receipt := range receipts {
if event, ok := d.pendingPDUs[receipt]; ok {
pdus[receipt] = event
pduCount++
if pduCount == limit {
break
}
}
}
}
return pdus, nil
}
func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
eduCount := 0
edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU)
if receipts, ok := d.associatedEDUs[serverName]; ok {
for receipt := range receipts {
if event, ok := d.pendingEDUs[receipt]; ok {
edus[receipt] = event
eduCount++
if eduCount == limit {
break
}
}
}
}
return edus, nil
}
func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingPDUs[receipt]; ok {
for destination := range destinations {
if _, ok := d.associatedPDUs[destination]; !ok {
d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{})
}
d.associatedPDUs[destination][receipt] = struct{}{}
}
return nil
} else {
return errors.New("PDU doesn't exist")
}
}
func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingEDUs[receipt]; ok {
for destination := range destinations {
if _, ok := d.associatedEDUs[destination]; !ok {
d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{})
}
d.associatedEDUs[destination][receipt] = struct{}{}
}
return nil
} else {
return errors.New("EDU doesn't exist")
}
}
func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if pdus, ok := d.associatedPDUs[serverName]; ok {
for _, receipt := range receipts {
delete(pdus, receipt)
}
}
return nil
}
func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if edus, ok := d.associatedEDUs[serverName]; ok {
for _, receipt := range receipts {
delete(edus, receipt)
}
}
return nil
}
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingPDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingEDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.blacklistedServers[serverName] = struct{}{}
return nil
}
func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
delete(d.blacklistedServers, serverName)
return nil
}
func (d *fakeDatabase) RemoveAllServersFromBlacklist() error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{})
return nil
}
func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
isBlacklisted := false
if _, ok := d.blacklistedServers[serverName]; ok {
isBlacklisted = true
}
return isBlacklisted, nil
}
type stubFederationRoomServerAPI struct {
rsapi.FederationRoomserverAPI
}
@ -291,7 +74,9 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont
type stubFederationClient struct {
api.FederationClient
shouldTxSucceed bool
shouldTxRelaySucceed bool
txCount atomic.Uint32
txRelayCount atomic.Uint32
}
func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
@ -304,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse
return gomatrixserverlib.RespSend{}, result
}
func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) {
var result error
if !f.shouldTxRelaySucceed {
result = fmt.Errorf("relay transaction failed")
}
f.txRelayCount.Add(1)
return gomatrixserverlib.EmptyResp{}, result
}
func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent {
t.Helper()
content := `{"type":"m.room.message"}`
@ -319,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU {
return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}
}
func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) {
func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) {
db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase)
fc := &stubFederationClient{
shouldTxSucceed: shouldTxSucceed,
shouldTxRelaySucceed: shouldTxRelaySucceed,
txCount: *atomic.NewUint32(0),
txRelayCount: *atomic.NewUint32(0),
}
rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline)
signingInfo := []*gomatrixserverlib.SigningIdentity{
{
KeyID: "ed21019:auto",
@ -344,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -373,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -402,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -432,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -462,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -513,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -564,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -596,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -628,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -662,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -696,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -730,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) {
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
wasBlacklisted := dest.statistics.MarkServerAlive()
queues.RetryServer(destination, wasBlacklisted)
checkRetry := func(log poll.LogT) poll.Result {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
@ -747,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -781,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) {
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
wasBlacklisted := dest.statistics.MarkServerAlive()
queues.RetryServer(destination, wasBlacklisted)
checkRetry := func(log poll.LogT) poll.Result {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
@ -801,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) {
// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
// db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -845,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) {
// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
// db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -889,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) {
// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
// db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -940,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -978,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
destination := gomatrixserverlib.ServerName("remotehost")
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true)
// NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
defer close()
defer func() {
@ -1023,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
wasBlacklisted := dest.statistics.MarkServerAlive()
queues.RetryServer(destination, wasBlacklisted)
checkRetry := func(log poll.LogT) poll.Result {
pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
assert.NoError(t, dbErrPDU)
@ -1038,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
})
}
func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(7)
failuresUntilAssumedOffline := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
ev := mustCreatePDU(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == failuresUntilAssumedOffline {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be assumed offline")
}
return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(7)
failuresUntilAssumedOffline := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
ev := mustCreateEDU(t)
err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == failuresUntilAssumedOffline {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be assumed offline")
}
return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
failuresUntilAssumedOffline := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
relayServers := []gomatrixserverlib.ServerName{"relayserver"}
queues.statistics.ForServer(destination).AddRelayServers(relayServers)
ev := mustCreatePDU(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == 1 {
if fc.txRelayCount.Load() == 1 {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 0 {
return poll.Success()
}
return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data))
}
return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load())
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline)
}
func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
failuresUntilAssumedOffline := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
relayServers := []gomatrixserverlib.ServerName{"relayserver"}
queues.statistics.ForServer(destination).AddRelayServers(relayServers)
ev := mustCreateEDU(t)
err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == 1 {
if fc.txRelayCount.Load() == 1 {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 0 {
return poll.Success()
}
return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data))
}
return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load())
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline)
}

View File

@ -0,0 +1,94 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/hex"
"io"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
fedAPI "github.com/matrix-org/dendrite/federationapi"
fedInternal "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
type fakeUserAPI struct {
userAPI.FederationUserAPI
}
func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error {
return nil
}
func TestHandleQueryProfile(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath()
base.PublicFederationAPIMux = fedMux
base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin
base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false
fedClient := fakeFedClient{}
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true)
userapi := fakeUserAPI{}
r, ok := fedapi.(*fedInternal.FederationInternalAPI)
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin)))
type queryContent struct{}
content := queryContent{}
err := req.SetContent(content)
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk)
httpReq, err := req.HTTPRequest()
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
// vars := map[string]string{"room_alias": "#room:server"}
w := httptest.NewRecorder()
// httpReq = mux.SetURLVars(httpReq, vars)
handler(w, httpReq)
res := w.Result()
data, _ := io.ReadAll(res.Body)
println(string(data))
assert.Equal(t, 200, res.StatusCode)
})
}

View File

@ -0,0 +1,94 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/hex"
"io"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
fedAPI "github.com/matrix-org/dendrite/federationapi"
fedclient "github.com/matrix-org/dendrite/federationapi/api"
fedInternal "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
type fakeFedClient struct {
fedclient.FederationClient
}
func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) {
return
}
func TestHandleQueryDirectory(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath()
base.PublicFederationAPIMux = fedMux
base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin
base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false
fedClient := fakeFedClient{}
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true)
userapi := fakeUserAPI{}
r, ok := fedapi.(*fedInternal.FederationInternalAPI)
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server"))
type queryContent struct{}
content := queryContent{}
err := req.SetContent(content)
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk)
httpReq, err := req.HTTPRequest()
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
// vars := map[string]string{"room_alias": "#room:server"}
w := httptest.NewRecorder()
// httpReq = mux.SetURLVars(httpReq, vars)
handler(w, httpReq)
res := w.Result()
data, _ := io.ReadAll(res.Body)
println(string(data))
assert.Equal(t, 200, res.StatusCode)
})
}

View File

@ -41,6 +41,12 @@ import (
"github.com/sirupsen/logrus"
)
const (
SendRouteName = "Send"
QueryDirectoryRouteName = "QueryDirectory"
QueryProfileRouteName = "QueryProfile"
)
// Setup registers HTTP handlers with the given ServeMux.
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
// path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled
@ -68,7 +74,7 @@ func Setup(
if base.EnableMetrics {
prometheus.MustRegister(
pduCountTotal, eduCountTotal,
internal.PDUCountTotal, internal.EDUCountTotal,
)
}
@ -138,7 +144,7 @@ func Setup(
cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
)).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
@ -248,7 +254,7 @@ func Setup(
httpReq, federation, cfg, rsAPI, fsAPI,
)
},
)).Methods(http.MethodGet)
)).Methods(http.MethodGet).Name(QueryDirectoryRouteName)
v1fedmux.Handle("/query/profile", MakeFedAPI(
"federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
@ -257,7 +263,7 @@ func Setup(
httpReq, userAPI, cfg,
)
},
)).Methods(http.MethodGet)
)).Methods(http.MethodGet).Name(QueryProfileRouteName)
v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI(
"federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,

View File

@ -17,26 +17,20 @@ package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
syncTypes "github.com/matrix-org/dendrite/syncapi/types"
)
const (
@ -56,26 +50,6 @@ const (
MetricsWorkMissingPrevEvents = "missing_prev_events"
)
var (
pduCountTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "recv_pdus",
Help: "Number of incoming PDUs from remote servers with labels for success",
},
[]string{"status"}, // 'success' or 'total'
)
eduCountTotal = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "recv_edus",
Help: "Number of incoming EDUs from remote servers",
},
)
)
var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse
// Send implements /_matrix/federation/v1/send/{txnID}
@ -123,18 +97,6 @@ func Send(
defer close(ch)
defer inFlightTxnsPerOrigin.Delete(index)
t := txnReq{
rsAPI: rsAPI,
keys: keys,
ourServerName: cfg.Matrix.ServerName,
federation: federation,
servers: servers,
keyAPI: keyAPI,
roomsMu: mu,
producer: producer,
inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound,
}
var txnEvents struct {
PDUs []json.RawMessage `json:"pdus"`
EDUs []gomatrixserverlib.EDU `json:"edus"`
@ -155,16 +117,23 @@ func Send(
}
}
// TODO: Really we should have a function to convert FederationRequest to txnReq
t.PDUs = txnEvents.PDUs
t.EDUs = txnEvents.EDUs
t.Origin = request.Origin()
t.TransactionID = txnID
t.Destination = cfg.Matrix.ServerName
t := internal.NewTxnReq(
rsAPI,
keyAPI,
cfg.Matrix.ServerName,
keys,
mu,
producer,
cfg.Matrix.Presence.EnableInbound,
txnEvents.PDUs,
txnEvents.EDUs,
request.Origin(),
txnID,
cfg.Matrix.ServerName)
util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs))
resp, jsonErr := t.processTransaction(httpReq.Context())
resp, jsonErr := t.ProcessTransaction(httpReq.Context())
if jsonErr != nil {
util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed")
return *jsonErr
@ -181,283 +150,3 @@ func Send(
ch <- res
return res
}
type txnReq struct {
gomatrixserverlib.Transaction
rsAPI api.FederationRoomserverAPI
keyAPI keyapi.FederationKeyAPI
ourServerName gomatrixserverlib.ServerName
keys gomatrixserverlib.JSONVerifier
federation txnFederationClient
roomsMu *internal.MutexByRoom
servers federationAPI.ServersInRoomProvider
producer *producers.SyncAPIProducer
inboundPresenceEnabled bool
}
// A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface {
LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error,
)
LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
t.processEDUs(ctx)
}()
results := make(map[string]gomatrixserverlib.PDUResult)
roomVersions := make(map[string]gomatrixserverlib.RoomVersion)
getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion {
if v, ok := roomVersions[roomID]; ok {
return v
}
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID)
return ""
}
roomVersions[roomID] = verRes.RoomVersion
return verRes.RoomVersion
}
for _, pdu := range t.PDUs {
pduCountTotal.WithLabelValues("total").Inc()
var header struct {
RoomID string `json:"room_id"`
}
if err := json.Unmarshal(pdu, &header); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event")
// We don't know the event ID at this point so we can't return the
// failure in the PDU results
continue
}
roomVersion := getRoomVersion(header.RoomID)
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if err != nil {
if _, ok := err.(gomatrixserverlib.BadJSONError); ok {
// Room version 6 states that homeservers should strictly enforce canonical JSON
// on PDUs.
//
// This enforces that the entire transaction is rejected if a single bad PDU is
// sent. It is unclear if this is the correct behaviour or not.
//
// See https://github.com/matrix-org/synapse/issues/7543
return nil, &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("PDU contains bad JSON"),
}
}
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu))
continue
}
if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") {
continue
}
if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) {
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: "Forbidden by server ACLs",
}
continue
}
if err = event.VerifyEventSignatures(ctx, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue
}
// pass the event to the roomserver which will do auth checks
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
// discarded by the caller of this function
if err = api.SendEvents(
ctx,
t.rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion),
},
t.Destination,
t.Origin,
api.DoNotSendToOtherServers,
nil,
true,
); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err)
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue
}
results[event.EventID()] = gomatrixserverlib.PDUResult{}
pduCountTotal.WithLabelValues("success").Inc()
}
wg.Wait()
return &gomatrixserverlib.RespSend{PDUs: results}, nil
}
// nolint:gocyclo
func (t *txnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs {
eduCountTotal.Inc()
switch e.Type {
case gomatrixserverlib.MTyping:
// https://matrix.org/docs/spec/server_server/latest#typing-notifications
var typingPayload struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
Typing bool `json:"typing"`
}
if err := json.Unmarshal(e.Content, &typingPayload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event")
continue
}
if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream")
}
case gomatrixserverlib.MDirectToDevice:
// https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
var directPayload gomatrixserverlib.ToDeviceMessage
if err := json.Unmarshal(e.Content, &directPayload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events")
continue
}
if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
for userID, byUser := range directPayload.Messages {
for deviceID, message := range byUser {
// TODO: check that the user and the device actually exist here
if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
sentry.CaptureException(err)
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
"sender": directPayload.Sender,
"user_id": userID,
"device_id": deviceID,
}).Error("Failed to send send-to-device event to JetStream")
}
}
}
case gomatrixserverlib.MDeviceListUpdate:
if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil {
sentry.CaptureException(err)
util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate")
}
case gomatrixserverlib.MReceipt:
// https://matrix.org/docs/spec/server_server/r0.1.4#receipts
payload := map[string]types.FederationReceiptMRead{}
if err := json.Unmarshal(e.Content, &payload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event")
continue
}
for roomID, receipt := range payload {
for userID, mread := range receipt.User {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender")
continue
}
if t.Origin != domain {
util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin)
continue
}
if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil {
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
"sender": t.Origin,
"user_id": userID,
"room_id": roomID,
"events": mread.EventIDs,
}).Error("Failed to send receipt event to JetStream")
continue
}
}
}
case types.MSigningKeyUpdate:
if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil {
sentry.CaptureException(err)
logrus.WithError(err).Errorf("Failed to process signing key update")
}
case gomatrixserverlib.MPresence:
if t.inboundPresenceEnabled {
if err := t.processPresence(ctx, e); err != nil {
logrus.WithError(err).Errorf("Failed to process presence update")
}
}
default:
util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU")
}
}
}
// processPresence handles m.receipt events
func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error {
payload := types.Presence{}
if err := json.Unmarshal(e.Content, &payload); err != nil {
return err
}
for _, content := range payload.Push {
if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
presence, ok := syncTypes.PresenceFromString(content.Presence)
if !ok {
continue
}
if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil {
return err
}
}
return nil
}
// processReceiptEvent sends receipt events to JetStream
func (t *txnReq) processReceiptEvent(ctx context.Context,
userID, roomID, receiptType string,
timestamp gomatrixserverlib.Timestamp,
eventIDs []string,
) error {
if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil {
return nil
} else if serverName == t.ourServerName {
return nil
} else if serverName != t.Origin {
return nil
}
// store every event
for _, eventID := range eventIDs {
if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil {
return fmt.Errorf("unable to set receipt event: %w", err)
}
}
return nil
}

View File

@ -1,552 +1,87 @@
package routing
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"net/http/httptest"
"testing"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
fedAPI "github.com/matrix-org/dendrite/federationapi"
fedInternal "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
var (
testRoomVersion = gomatrixserverlib.RoomVersionV1
testData = []json.RawMessage{
[]byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`),
// messages
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`),
}
testEvents = []*gomatrixserverlib.HeaderedEvent{}
testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
)
type sendContent struct {
PDUs []json.RawMessage `json:"pdus"`
EDUs []gomatrixserverlib.EDU `json:"edus"`
}
func init() {
for _, j := range testData {
e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion)
func TestHandleSend(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath()
base.PublicFederationAPIMux = fedMux
base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin
base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false
fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
r, ok := fedapi.(*fedInternal.FederationInternalAPI)
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234")
content := sendContent{}
err := req.SetContent(content)
if err != nil {
panic("cannot load test data: " + err.Error())
t.Fatalf("Error: %s", err.Error())
}
h := e.Headered(testRoomVersion)
testEvents = append(testEvents, h)
if e.StateKey() != nil {
testStateEvents[gomatrixserverlib.StateKeyTuple{
EventType: e.Type(),
StateKey: *e.StateKey(),
}] = h
}
}
}
type testRoomserverAPI struct {
api.RoomserverInternalAPITrace
inputRoomEvents []api.InputRoomEvent
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
}
func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID())
}
return nil
}
// Query the latest events and state for a room from the room server.
func (t *testRoomserverAPI) QueryLatestEventsAndState(
ctx context.Context,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
r := t.queryLatestEventsAndState(request)
response.RoomExists = r.RoomExists
response.RoomVersion = testRoomVersion
response.LatestEvents = r.LatestEvents
response.StateEvents = r.StateEvents
response.Depth = r.Depth
return nil
}
// Query the state after a list of events in a room from the room server.
func (t *testRoomserverAPI) QueryStateAfterEvents(
ctx context.Context,
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
response.RoomVersion = testRoomVersion
res := t.queryStateAfterEvents(request)
response.PrevEventsExist = res.PrevEventsExist
response.RoomExists = res.RoomExists
response.StateEvents = res.StateEvents
return nil
}
// Query a list of events by event ID.
func (t *testRoomserverAPI) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
res := t.queryEventsByID(request)
response.Events = res.Events
return nil
}
// Query if a server is joined to a room
func (t *testRoomserverAPI) QueryServerJoinedToRoom(
ctx context.Context,
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
response.RoomExists = true
response.IsInRoom = true
return nil
}
// Asks for the room version for a given room.
func (t *testRoomserverAPI) QueryRoomVersionForRoom(
ctx context.Context,
request *api.QueryRoomVersionForRoomRequest,
response *api.QueryRoomVersionForRoomResponse,
) error {
response.RoomVersion = testRoomVersion
return nil
}
func (t *testRoomserverAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
) error {
res.Banned = false
return nil
}
type txnFedClient struct {
state map[string]gomatrixserverlib.RespState // event_id to response
stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response
getEvent map[string]gomatrixserverlib.Transaction // event_id to response
getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error)
}
func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error,
) {
fmt.Println("testFederationClient.LookupState", eventID)
r, ok := c.state[eventID]
if !ok {
err = fmt.Errorf("txnFedClient: no /state for event %s", eventID)
return
}
res = r
return
}
func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) {
fmt.Println("testFederationClient.LookupStateIDs", eventID)
r, ok := c.stateIDs[eventID]
if !ok {
err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID)
return
}
res = r
return
}
func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) {
fmt.Println("testFederationClient.GetEvent", eventID)
r, ok := c.getEvent[eventID]
if !ok {
err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID)
return
}
res = r
return
}
func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) {
return c.getMissingEvents(missing)
}
func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq {
t := &txnReq{
rsAPI: rsAPI,
keys: &test.NopJSONVerifier{},
federation: fedClient,
roomsMu: internal.NewMutexByRoom(),
}
t.PDUs = pdus
t.Origin = testOrigin
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
t.Destination = testDestination
return t
}
func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) {
res, err := txn.processTransaction(context.Background())
req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk)
httpReq, err := req.HTTPRequest()
if err != nil {
t.Errorf("txn.processTransaction returned an error: %v", err)
return
}
if len(res.PDUs) != len(txn.PDUs) {
t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs))
return
}
NextPDU:
for eventID, result := range res.PDUs {
if result.Error == "" {
continue
}
for _, eventIDWantError := range pdusWithErrors {
if eventID == eventIDWantError {
break NextPDU
}
}
t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error)
t.Fatalf("Error: %s", err.Error())
}
vars := map[string]string{"txnID": "1234"}
w := httptest.NewRecorder()
httpReq = mux.SetURLVars(httpReq, vars)
handler(w, httpReq)
res := w.Result()
assert.Equal(t, 200, res.StatusCode)
})
}
/*
func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) {
NextTuple:
for _, t := range tuples {
for _, o := range omitTuples {
if t == o {
break NextTuple
}
}
h, ok := testStateEvents[t]
if ok {
result = append(result, h)
}
}
return
}
*/
func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) {
for _, g := range got {
fmt.Println("GOT ", g.Event.EventID())
}
if len(got) != len(want) {
t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want))
return
}
for i := range got {
if got[i].Event.EventID() != want[i].EventID() {
t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID())
}
}
}
// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on
// to the roomserver. It's the most basic test possible.
func TestBasicTransaction(t *testing.T) {
rsAPI := &testRoomserverAPI{}
pdus := []json.RawMessage{
testData[len(testData)-1], // a message event
}
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
}
// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
// as it does the auth check.
func TestTransactionFailAuthChecks(t *testing.T) {
rsAPI := &testRoomserverAPI{}
pdus := []json.RawMessage{
testData[len(testData)-1], // a message event
}
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, []string{})
// expect message to be sent to the roomserver
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
}
// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events,
// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response,
// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in
// topological order and sent to the roomserver.
/*
func TestTransactionFetchMissingPrevEvents(t *testing.T) {
haveEvent := testEvents[len(testEvents)-3]
prevEvent := testEvents[len(testEvents)-2]
inputEvent := testEvents[len(testEvents)-1]
var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions
rsAPI = &testRoomserverAPI{
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
res := api.QueryEventsByIDResponse{}
for _, ev := range testEvents {
for _, id := range req.EventIDs {
if ev.EventID() == id {
res.Events = append(res.Events, ev)
}
}
}
return res
},
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
return api.QueryStateAfterEventsResponse{
PrevEventsExist: true,
StateEvents: testEvents[:5],
}
},
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
missingPrevEvent := []string{"missing_prev_event"}
if len(req.PrevEventIDs) == 1 {
switch req.PrevEventIDs[0] {
case haveEvent.EventID():
missingPrevEvent = []string{}
case prevEvent.EventID():
// we only have this event if we've been send prevEvent
if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() {
missingPrevEvent = []string{}
}
}
}
return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true,
MissingAuthEventIDs: []string{},
MissingPrevEventIDs: missingPrevEvent,
}
},
queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse {
return api.QueryLatestEventsAndStateResponse{
RoomExists: true,
Depth: haveEvent.Depth(),
LatestEvents: []gomatrixserverlib.EventReference{
haveEvent.EventReference(),
},
StateEvents: fromStateTuples(req.StateToFetch, nil),
}
},
}
cli := &txnFedClient{
getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) {
if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) {
t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID())
}
if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) {
t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID())
}
return gomatrixserverlib.RespMissingEvents{
Events: []*gomatrixserverlib.Event{
prevEvent.Unwrap(),
},
}, nil
},
}
pdus := []json.RawMessage{
inputEvent.JSON(),
}
txn := mustCreateTransaction(rsAPI, cli, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent})
}
// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill
// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids
// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in
// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite
// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to
// continue. The DAG looks something like:
// FE GME TXN
// A ---> B ---> C ---> D
// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in:
// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event.
// - state resolution is done to check C is allowed.
// This results in B being sent as an outlier FIRST, then C,D.
func TestTransactionFetchMissingStateByStateIDs(t *testing.T) {
eventA := testEvents[len(testEvents)-5]
// this is also len(testEvents)-4
eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomPowerLevels,
StateKey: "",
}]
eventC := testEvents[len(testEvents)-3]
eventD := testEvents[len(testEvents)-2]
fmt.Println("a:", eventA.EventID())
fmt.Println("b:", eventB.EventID())
fmt.Println("c:", eventC.EventID())
fmt.Println("d:", eventD.EventID())
var rsAPI *testRoomserverAPI
rsAPI = &testRoomserverAPI{
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
omitTuples := []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomPowerLevels,
StateKey: "",
},
}
askingForEvent := req.PrevEventIDs[0]
haveEventB := false
haveEventC := false
for _, ev := range rsAPI.inputRoomEvents {
switch ev.Event.EventID() {
case eventB.EventID():
haveEventB = true
omitTuples = nil // include event B now
case eventC.EventID():
haveEventC = true
}
}
prevEventExists := false
if askingForEvent == eventC.EventID() {
prevEventExists = haveEventC
} else if askingForEvent == eventB.EventID() {
prevEventExists = haveEventB
}
var stateEvents []*gomatrixserverlib.HeaderedEvent
if prevEventExists {
stateEvents = fromStateTuples(req.StateToFetch, omitTuples)
}
return api.QueryStateAfterEventsResponse{
PrevEventsExist: prevEventExists,
RoomExists: true,
StateEvents: stateEvents,
}
},
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
askingForEvent := req.PrevEventIDs[0]
haveEventB := false
haveEventC := false
for _, ev := range rsAPI.inputRoomEvents {
switch ev.Event.EventID() {
case eventB.EventID():
haveEventB = true
case eventC.EventID():
haveEventC = true
}
}
prevEventExists := false
if askingForEvent == eventC.EventID() {
prevEventExists = haveEventC
} else if askingForEvent == eventB.EventID() {
prevEventExists = haveEventB
}
var missingPrevEvent []string
if !prevEventExists {
missingPrevEvent = []string{"test"}
}
return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true,
MissingAuthEventIDs: []string{},
MissingPrevEventIDs: missingPrevEvent,
}
},
queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse {
omitTuples := []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""},
}
return api.QueryLatestEventsAndStateResponse{
RoomExists: true,
Depth: eventA.Depth(),
LatestEvents: []gomatrixserverlib.EventReference{
eventA.EventReference(),
},
StateEvents: fromStateTuples(req.StateToFetch, omitTuples),
}
},
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
var res api.QueryEventsByIDResponse
fmt.Println("queryEventsByID ", req.EventIDs)
for _, wantEventID := range req.EventIDs {
for _, ev := range testStateEvents {
// roomserver is missing the power levels event unless it's been sent to us recently as an outlier
if wantEventID == eventB.EventID() {
fmt.Println("Asked for pl event")
for _, inEv := range rsAPI.inputRoomEvents {
fmt.Println("recv ", inEv.Event.EventID())
if inEv.Event.EventID() == wantEventID {
res.Events = append(res.Events, inEv.Event)
break
}
}
continue
}
if ev.EventID() == wantEventID {
res.Events = append(res.Events, ev)
}
}
}
return res
},
}
// /state_ids for event B returns every state event but B (it's the state before)
var authEventIDs []string
var stateEventIDs []string
for _, ev := range testStateEvents {
if ev.EventID() == eventB.EventID() {
continue
}
// state res checks what auth events you give it, and this isn't a valid auth event
if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility {
authEventIDs = append(authEventIDs, ev.EventID())
}
stateEventIDs = append(stateEventIDs, ev.EventID())
}
cli := &txnFedClient{
stateIDs: map[string]gomatrixserverlib.RespStateIDs{
eventB.EventID(): {
StateEventIDs: stateEventIDs,
AuthEventIDs: authEventIDs,
},
},
// /event for event B returns it
getEvent: map[string]gomatrixserverlib.Transaction{
eventB.EventID(): {
PDUs: []json.RawMessage{
eventB.JSON(),
},
},
},
// /get_missing_events should be done exactly once
getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) {
if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) {
t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID())
}
if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) {
t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID())
}
// just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events
return gomatrixserverlib.RespMissingEvents{
Events: []*gomatrixserverlib.Event{
eventC.Unwrap(),
},
}, nil
},
}
pdus := []json.RawMessage{
eventD.JSON(),
}
txn := mustCreateTransaction(rsAPI, cli, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD})
}
*/

View File

@ -1,6 +1,7 @@
package statistics
import (
"context"
"math"
"math/rand"
"sync"
@ -28,12 +29,22 @@ type Statistics struct {
// just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32
// How many times should we tolerate consecutive failures before we
// mark the destination as offline. At this point we should attempt
// to send messages to the user's async relay servers if we know them.
FailuresUntilAssumedOffline uint32
}
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
func NewStatistics(
db storage.Database,
failuresUntilBlacklist uint32,
failuresUntilAssumedOffline uint32,
) Statistics {
return Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
FailuresUntilAssumedOffline: failuresUntilAssumedOffline,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics),
}
@ -52,6 +63,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{
statistics: s,
serverName: serverName,
knownRelayServers: []gomatrixserverlib.ServerName{},
}
s.servers[serverName] = server
s.mutex.Unlock()
@ -61,10 +73,32 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
} else {
server.blacklisted.Store(blacklisted)
}
assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName)
} else {
server.assumedOffline.Store(assumedOffline)
}
knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName)
} else {
server.relayMutex.Lock()
server.knownRelayServers = knownRelayServers
server.relayMutex.Unlock()
}
}
return server
}
type SendMethod uint8
const (
SendDirect SendMethod = iota
SendViaRelay
)
// ServerStatistics contains information about our interactions with a
// remote federated host, e.g. how many times we were successful, how
// many times we failed etc. It also manages the backoff time and black-
@ -73,12 +107,15 @@ type ServerStatistics struct {
statistics *Statistics //
serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted
assumedOffline atomic.Bool // is the node assumed to be offline
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
knownRelayServers []gomatrixserverlib.ServerName
relayMutex sync.Mutex
}
const maxJitterMultiplier = 1.4
@ -113,15 +150,21 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
// attempt, which increases the sent counter and resets the idle and
// failure counters. If a host was blacklisted at this point then
// we will unblacklist it.
func (s *ServerStatistics) Success() {
// `relay` specifies whether the success was to the actual destination
// or one of their relay servers.
func (s *ServerStatistics) Success(method SendMethod) {
s.cancel()
s.backoffCount.Store(0)
// NOTE : Sending to the final destination vs. a relay server has
// slightly different semantics.
if method == SendDirect {
s.successCounter.Inc()
if s.statistics.DB != nil {
if s.blacklisted.Load() && s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
}
}
}
}
// Failure marks a failure and starts backing off if needed.
@ -139,7 +182,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
// start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) {
if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
backoffCount := s.backoffCount.Inc()
if backoffCount >= s.statistics.FailuresUntilAssumedOffline {
s.assumedOffline.CompareAndSwap(false, true)
if s.statistics.DB != nil {
if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName)
}
}
}
if backoffCount >= s.statistics.FailuresUntilBlacklist {
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
@ -157,13 +211,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
s.backoffUntil.Store(until)
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
s.statistics.backoffMutex.Unlock()
}
return s.backoffUntil.Load().(time.Time), false
}
// MarkServerAlive removes the assumed offline and blacklisted statuses from this server.
// Returns whether the server was blacklisted before this point.
func (s *ServerStatistics) MarkServerAlive() bool {
s.removeAssumedOffline()
wasBlacklisted := s.removeBlacklist()
return wasBlacklisted
}
// ClearBackoff stops the backoff timer for this destination if it is running
// and removes the timer from the backoffTimers map.
func (s *ServerStatistics) ClearBackoff() {
@ -191,13 +253,13 @@ func (s *ServerStatistics) backoffFinished() {
}
// BackoffInfo returns information about the current or previous backoff.
// Returns the last backoffUntil time and whether the server is currently blacklisted or not.
func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
// Returns the last backoffUntil time.
func (s *ServerStatistics) BackoffInfo() *time.Time {
until, ok := s.backoffUntil.Load().(time.Time)
if ok {
return &until, s.blacklisted.Load()
return &until
}
return nil, s.blacklisted.Load()
return nil
}
// Blacklisted returns true if the server is blacklisted and false
@ -206,10 +268,33 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
// AssumedOffline returns true if the server is assumed offline and false
// otherwise.
func (s *ServerStatistics) AssumedOffline() bool {
return s.assumedOffline.Load()
}
// removeBlacklist removes the blacklisted status from the server.
// Returns whether the server was blacklisted.
func (s *ServerStatistics) removeBlacklist() bool {
var wasBlacklisted bool
if s.Blacklisted() {
wasBlacklisted = true
_ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName)
}
s.cancel()
s.backoffCount.Store(0)
return wasBlacklisted
}
// removeAssumedOffline removes the assumed offline status from the server.
func (s *ServerStatistics) removeAssumedOffline() {
if s.AssumedOffline() {
_ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName)
}
s.assumedOffline.Store(false)
}
// SuccessCount returns the number of successful requests. This is
@ -217,3 +302,46 @@ func (s *ServerStatistics) RemoveBlacklist() {
func (s *ServerStatistics) SuccessCount() uint32 {
return s.successCounter.Load()
}
// KnownRelayServers returns the list of relay servers associated with this
// server.
func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName {
s.relayMutex.Lock()
defer s.relayMutex.Unlock()
return s.knownRelayServers
}
func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) {
seenSet := make(map[gomatrixserverlib.ServerName]bool)
uniqueList := []gomatrixserverlib.ServerName{}
for _, srv := range relayServers {
if seenSet[srv] {
continue
}
seenSet[srv] = true
uniqueList = append(uniqueList, srv)
}
err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList)
if err != nil {
logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList)
return
}
for _, newServer := range uniqueList {
alreadyKnown := false
knownRelayServers := s.KnownRelayServers()
for _, srv := range knownRelayServers {
if srv == newServer {
alreadyKnown = true
}
}
if !alreadyKnown {
{
s.relayMutex.Lock()
s.knownRelayServers = append(s.knownRelayServers, newServer)
s.relayMutex.Unlock()
}
}
}
}

View File

@ -4,17 +4,26 @@ import (
"math"
"testing"
"time"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
FailuresUntilAssumedOffline = 3
FailuresUntilBlacklist = 8
)
func TestBackoff(t *testing.T) {
stats := NewStatistics(nil, 7)
stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
server := ServerStatistics{
statistics: &stats,
serverName: "test.com",
}
// Start by checking that counting successes works.
server.Success()
server.Success(SendDirect)
if successes := server.SuccessCount(); successes != 1 {
t.Fatalf("Expected success count 1, got %d", successes)
}
@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) {
// side effects since a backoff is already in progress. If it does
// then we'll fail.
until, blacklisted := server.Failure()
// Get the duration.
_, blacklist := server.BackoffInfo()
blacklist := server.Blacklisted()
assumedOffline := server.AssumedOffline()
duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that
@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) {
server.cancel()
server.backoffStarted.Store(false)
if i >= stats.FailuresUntilAssumedOffline {
if !assumedOffline {
t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i)
}
}
// Check if we should be assumed offline by now.
if i >= stats.FailuresUntilAssumedOffline {
if !assumedOffline {
t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i)
} else {
t.Logf("Backoff %d is assumed offline as expected", i)
}
} else {
if assumedOffline {
t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i)
} else {
t.Logf("Backoff %d is not assumed offline as expected", i)
}
}
// Check if we should be blacklisted by now.
if i >= stats.FailuresUntilBlacklist {
if !blacklist {
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
} else if blacklist != blacklisted {
t.Fatalf("BackoffInfo and Failure returned different blacklist values")
t.Fatalf("Blacklisted and Failure returned different blacklist values")
} else {
t.Logf("Backoff %d is blacklisted as expected", i)
continue
}
} else {
if blacklist {
t.Fatalf("Backoff %d should not have resulted in blacklist but did", i)
} else {
t.Logf("Backoff %d is not blacklisted as expected", i)
}
}
// Check if the duration is what we expect.
@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) {
}
}
}
func TestRelayServersListing(t *testing.T) {
stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline)
server := ServerStatistics{statistics: &stats}
server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"})
relayServers := server.KnownRelayServers()
assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers)
server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"})
relayServers = server.KnownRelayServers()
assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers)
}

View File

@ -20,11 +20,12 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/federationapi/types"
)
type Database interface {
P2PDatabase
gomatrixserverlib.KeyDatabase
UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
@ -34,16 +35,16 @@ type Database interface {
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
@ -54,6 +55,18 @@ type Database interface {
RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
// Adds the server to the list of assumed offline servers.
// If the server already exists in the table, nothing happens and returns success.
SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
// Removes the server from the list of assumed offline servers.
// If the server doesn't exist in the table, nothing happens and returns success.
RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
// Purges all entries from the assumed offline table.
RemoveAllServersAssumedOffline(ctx context.Context) error
// Gets whether the provided server is present in the table.
// If it is present, returns true. If not, returns false.
IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error)
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
@ -74,3 +87,21 @@ type Database interface {
PurgeRoom(ctx context.Context, roomID string) error
}
type P2PDatabase interface {
// Stores the given list of servers as relay servers for the provided destination server.
// Providing duplicates will only lead to a single entry and won't lead to an error.
P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
// Get the list of relay servers associated with the provided destination server.
// If no entry exists in the table, an empty list is returned and does not result in an error.
P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
// Deletes any entries for the provided destination server that match the provided relayServers list.
// If any of the provided servers don't match an entry, nothing happens and no error is returned.
P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
// Deletes all entries for the provided destination server.
// If the destination server doesn't exist in the table, nothing happens and no error is returned.
P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error
}

View File

@ -0,0 +1,107 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const assumedOfflineSchema = `
CREATE TABLE IF NOT EXISTS federationsender_assumed_offline(
-- The assumed offline server name
server_name TEXT PRIMARY KEY NOT NULL
);
`
const insertAssumedOfflineSQL = "" +
"INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectAssumedOfflineSQL = "" +
"SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAssumedOfflineSQL = "" +
"DELETE FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAllAssumedOfflineSQL = "" +
"TRUNCATE federationsender_assumed_offline"
type assumedOfflineStatements struct {
db *sql.DB
insertAssumedOfflineStmt *sql.Stmt
selectAssumedOfflineStmt *sql.Stmt
deleteAssumedOfflineStmt *sql.Stmt
deleteAllAssumedOfflineStmt *sql.Stmt
}
func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) {
s = &assumedOfflineStatements{
db: db,
}
_, err = db.Exec(assumedOfflineSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL},
{&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL},
{&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL},
{&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL},
}.Prepare(db)
}
func (s *assumedOfflineStatements) InsertAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) SelectAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is assume offline, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *assumedOfflineStatements) DeleteAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) DeleteAllAssumedOffline(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View File

@ -0,0 +1,137 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayServersSchema = `
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
-- The destination server name
server_name TEXT NOT NULL,
-- The relay server name for a given destination
relay_server_name TEXT NOT NULL,
UNIQUE (server_name, relay_server_name)
);
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
ON federationsender_relay_servers (server_name);
`
const insertRelayServersSQL = "" +
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
deleteRelayServersStmt *sql.Stmt
deleteAllRelayServersStmt *sql.Stmt
}
func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
s = &relayServersStatements{
db: db,
}
_, err = db.Exec(relayServersSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertRelayServersStmt, insertRelayServersSQL},
{&s.selectRelayServersStmt, selectRelayServersSQL},
{&s.deleteRelayServersStmt, deleteRelayServersSQL},
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
}.Prepare(db)
}
func (s *relayServersStatements) InsertRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) SelectRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var relayServer string
if err = rows.Scan(&relayServer); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(relayServer))
}
return result, nil
}
func (s *relayServersStatements) DeleteRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers))
return err
}
func (s *relayServersStatements) DeleteAllRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View File

@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
if err != nil {
return nil, err
}
relayServers, err := NewPostgresRelayServersTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil {
return nil, err
@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON,

View File

@ -0,0 +1,42 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
// We don't actually export the NIDs but we need the caller to be able
// to pass them back so that we can clean up if the transaction sends
// successfully.
package receipt
import "fmt"
// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry
// in some database table.
// The internal nid value cannot be modified after a Receipt has been created.
// This guarantees a receipt will always refer to the same table entry that it was created
// to represent.
type Receipt struct {
nid int64
}
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) GetNID() int64 {
return r.nid
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}

View File

@ -20,6 +20,7 @@ import (
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal/caching"
@ -37,6 +38,8 @@ type Database struct {
FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist
FederationAssumedOffline tables.FederationAssumedOffline
FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
@ -44,22 +47,6 @@ type Database struct {
ServerSigningKeys tables.FederationServerSigningKeys
}
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
// We don't actually export the NIDs but we need the caller to be able
// to pass them back so that we can clean up if the transaction sends
// successfully.
type Receipt struct {
nid int64
}
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}
// UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update, or nil if this was a duplicate message.
// This is called when we receive a message from kafka, so we pass in
@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts(
// GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender.
// Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
func (d *Database) GetAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
}
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
func (d *Database) GetJoinedHostsForRooms(
ctx context.Context,
roomIDs []string,
excludeSelf,
excludeBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil {
return nil, err
@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string,
// metadata entries.
func (d *Database) StoreJSON(
ctx context.Context, js string,
) (*Receipt, error) {
) (*receipt.Receipt, error) {
var nid int64
var err error
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -149,18 +143,21 @@ func (d *Database) StoreJSON(
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
return &Receipt{
nid: nid,
}, nil
newReceipt := receipt.NewReceipt(nid)
return &newReceipt, nil
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
func (d *Database) AddServerToBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
func (d *Database) RemoveServerFromBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
})
@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error {
})
}
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
func (d *Database) IsServerBlacklisted(
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
}
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName)
})
}
func (d *Database) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName)
})
}
func (d *Database) RemoveAllServersAssumedOffline(
ctx context.Context,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn)
})
}
func (d *Database) IsServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName)
}
func (d *Database) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers)
})
}
func (d *Database) P2PGetRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName)
}
func (d *Database) P2PRemoveRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers)
})
}
func (d *Database) P2PRemoveAllRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName)
})
}
func (d *Database) AddOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) RenewOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
func (d *Database) GetOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID,
peekID string,
) (*types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
func (d *Database) GetOutboundPeeks(
ctx context.Context,
roomID string,
) ([]types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
}
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) AddInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) RenewInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
func (d *Database) GetInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
) (*types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
func (d *Database) GetInboundPeeks(
ctx context.Context,
roomID string,
) ([]types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
}
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
func (d *Database) UpdateNotaryKeys(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
serverKeys gomatrixserverlib.ServerKeys,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
validUntil := serverKeys.ValidUntilTS
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv
}
func (d *Database) GetNotaryKeys(
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID,
ctx context.Context,
serverName gomatrixserverlib.ServerName,
optKeyIDs []gomatrixserverlib.KeyID,
) (sks []gomatrixserverlib.ServerKeys, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)

View File

@ -22,6 +22,7 @@ import (
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib"
)
@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{
func (d *Database) AssociateEDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
dbReceipt *receipt.Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
) error {
@ -66,7 +67,7 @@ func (d *Database) AssociateEDUWithDestinations(
txn, // SQL transaction
eduType, // EDU type for coalescing
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
)
}
@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs(
serverName gomatrixserverlib.ServerName,
limit int,
) (
edus map[*Receipt]*gomatrixserverlib.EDU,
edus map[*receipt.Receipt]*gomatrixserverlib.EDU,
err error,
) {
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
if err != nil {
@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs(
retrieve := make([]int64, 0, len(nids))
for _, nid := range nids {
if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok {
edus[&Receipt{nid}] = edu
newReceipt := receipt.NewReceipt(nid)
edus[&newReceipt] = edu
} else {
retrieve = append(retrieve, nid)
}
@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs(
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
edus[&Receipt{nid}] = &event
newReceipt := receipt.NewReceipt(nid)
edus[&newReceipt] = &event
d.Cache.StoreFederationQueuedEDU(nid, &event)
}
@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs(
func (d *Database) CleanEDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*Receipt,
receipts []*receipt.Receipt,
) error {
if len(receipts) == 0 {
return errors.New("expected receipt")
@ -132,7 +135,7 @@ func (d *Database) CleanEDUs(
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
nids[i] = receipts[i].GetNID()
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {

View File

@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib"
)
@ -30,7 +31,7 @@ import (
func (d *Database) AssociatePDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
dbReceipt *receipt.Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
@ -40,7 +41,7 @@ func (d *Database) AssociatePDUWithDestinations(
txn, // SQL transaction
"", // transaction ID
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
)
}
return err
@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs(
serverName gomatrixserverlib.ServerName,
limit int,
) (
events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent,
err error,
) {
// Strictly speaking this doesn't need to be using the writer
@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs(
// a guarantee of transactional isolation, it's actually useful
// to know in SQLite mode that nothing else is trying to modify
// the database.
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
if err != nil {
@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs(
retrieve := make([]int64, 0, len(nids))
for _, nid := range nids {
if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok {
events[&Receipt{nid}] = event
newReceipt := receipt.NewReceipt(nid)
events[&newReceipt] = event
} else {
retrieve = append(retrieve, nid)
}
@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs(
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
events[&Receipt{nid}] = &event
newReceipt := receipt.NewReceipt(nid)
events[&newReceipt] = &event
d.Cache.StoreFederationQueuedPDU(nid, &event)
}
@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs(
func (d *Database) CleanPDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*Receipt,
receipts []*receipt.Receipt,
) error {
if len(receipts) == 0 {
return errors.New("expected receipt")
@ -111,7 +114,7 @@ func (d *Database) CleanPDUs(
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
nids[i] = receipts[i].GetNID()
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {

View File

@ -0,0 +1,107 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const assumedOfflineSchema = `
CREATE TABLE IF NOT EXISTS federationsender_assumed_offline(
-- The assumed offline server name
server_name TEXT PRIMARY KEY NOT NULL
);
`
const insertAssumedOfflineSQL = "" +
"INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectAssumedOfflineSQL = "" +
"SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAssumedOfflineSQL = "" +
"DELETE FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAllAssumedOfflineSQL = "" +
"DELETE FROM federationsender_assumed_offline"
type assumedOfflineStatements struct {
db *sql.DB
insertAssumedOfflineStmt *sql.Stmt
selectAssumedOfflineStmt *sql.Stmt
deleteAssumedOfflineStmt *sql.Stmt
deleteAllAssumedOfflineStmt *sql.Stmt
}
func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) {
s = &assumedOfflineStatements{
db: db,
}
_, err = db.Exec(assumedOfflineSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL},
{&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL},
{&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL},
{&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL},
}.Prepare(db)
}
func (s *assumedOfflineStatements) InsertAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) SelectAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is assume offline, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *assumedOfflineStatements) DeleteAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) DeleteAllAssumedOffline(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View File

@ -0,0 +1,148 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayServersSchema = `
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
-- The destination server name
server_name TEXT NOT NULL,
-- The relay server name for a given destination
relay_server_name TEXT NOT NULL,
UNIQUE (server_name, relay_server_name)
);
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
ON federationsender_relay_servers (server_name);
`
const insertRelayServersSQL = "" +
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
// deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic
deleteAllRelayServersStmt *sql.Stmt
}
func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
s = &relayServersStatements{
db: db,
}
_, err = db.Exec(relayServersSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertRelayServersStmt, insertRelayServersSQL},
{&s.selectRelayServersStmt, selectRelayServersSQL},
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
}.Prepare(db)
}
func (s *relayServersStatements) InsertRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) SelectRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var relayServer string
if err = rows.Scan(&relayServer); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(relayServer))
}
return result, nil
}
func (s *relayServersStatements) DeleteRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1)
deleteStmt, err := s.db.Prepare(deleteSQL)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
params := make([]interface{}, len(relayServers)+1)
params[0] = serverName
for i, v := range relayServers {
params[i+1] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *relayServersStatements) DeleteAllRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View File

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
if err != nil {
return nil, err
}
relayServers, err := NewSQLiteRelayServersTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil {
return nil, err
@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys,

View File

@ -6,14 +6,13 @@ import (
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) {
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
})
}
func TestServersAssumedOffline(t *testing.T) {
server1 := gomatrixserverlib.ServerName("server1")
server2 := gomatrixserverlib.ServerName("server2")
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
// Set server1 & server2 as assumed offline.
err := db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
err = db.SetServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
// Ensure both servers are assumed offline.
isOffline, err := db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.True(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.True(t, isOffline)
// Set server1 as not assumed offline.
err = db.RemoveServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
// Ensure both servers have correct state.
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.False(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.True(t, isOffline)
// Re-set server1 as assumed offline.
err = db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
// Ensure server1 is assumed offline.
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.True(t, isOffline)
err = db.RemoveAllServersAssumedOffline(context.Background())
assert.Nil(t, err)
// Ensure both servers have correct state.
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.False(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.False(t, isOffline)
})
}
func TestRelayServersStored(t *testing.T) {
server := gomatrixserverlib.ServerName("server")
relayServer1 := gomatrixserverlib.ServerName("relayserver1")
relayServer2 := gomatrixserverlib.ServerName("relayserver2")
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err)
relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0])
err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Zero(t, len(relayServers))
err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2})
assert.Nil(t, err)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0])
assert.Equal(t, relayServer2, relayServers[1])
err = db.P2PRemoveAllRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Zero(t, len(relayServers))
})
}

View File

@ -49,6 +49,19 @@ type FederationQueueJSON interface {
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationQueueTransactions interface {
InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
type FederationTransactionJSON interface {
InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationJoinedHosts interface {
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
@ -66,6 +79,20 @@ type FederationBlacklist interface {
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
}
type FederationAssumedOffline interface {
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error
}
type FederationRelayServers interface {
InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
}
type FederationOutboundPeeks interface {
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)

View File

@ -0,0 +1,224 @@
package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
server4 = "server4"
)
type RelayServersDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationRelayServers
}
func mustCreateRelayServersTable(
t *testing.T,
dbType test.DBType,
) (database RelayServersDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationRelayServers
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelayServersTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteRelayServersTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = RelayServersDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func Equal(a, b []gomatrixserverlib.ServerName) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestShouldInsertRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldInsertRelayServersWithDuplicates(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2}
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
// Insert the same list again, this shouldn't fail and should have no effect.
err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldGetRelayServersUnknownDestination(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
// Query relay servers for a destination that doesn't exist in the table.
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, []gomatrixserverlib.ServerName{}) {
t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers)
}
})
}
func TestShouldDeleteCorrectRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
relayServers1 := []gomatrixserverlib.ServerName{server2, server3}
relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4}
err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error())
}
expectedRelayServers := []gomatrixserverlib.ServerName{server3}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldDeleteAllRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteAllRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}

20
go.mod
View File

@ -22,9 +22,9 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f
github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.15
github.com/nats-io/nats-server/v2 v2.9.8
github.com/nats-io/nats.go v1.20.0
@ -37,17 +37,17 @@ require (
github.com/prometheus/client_golang v1.13.0
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.1
github.com/tidwall/gjson v1.14.3
github.com/tidwall/gjson v1.14.4
github.com/tidwall/sjson v1.2.5
github.com/uber/jaeger-client-go v2.30.0+incompatible
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.1.0
golang.org/x/crypto v0.5.0
golang.org/x/image v0.1.0
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
golang.org/x/net v0.1.0
golang.org/x/term v0.1.0
golang.org/x/net v0.5.0
golang.org/x/term v0.4.0
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0
@ -119,12 +119,12 @@ require (
github.com/prometheus/procfs v0.8.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
go.etcd.io/bbolt v1.3.6 // indirect
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect
golang.org/x/mod v0.6.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/text v0.4.0 // indirect
golang.org/x/sys v0.4.0 // indirect
golang.org/x/text v0.6.0 // indirect
golang.org/x/time v0.1.0 // indirect
golang.org/x/tools v0.2.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect

42
go.sum
View File

@ -348,16 +348,12 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f h1:niRWEVkeeekpjxwnMhKn8PD0PUloDsNXP8W+Ez/co/M=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb h1:2L+ltfNKab56FoBBqAvbBLjoAbxwwoZie+B8d+Mp3JI=
github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
@ -494,12 +490,13 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
@ -543,8 +540,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -625,8 +622,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw=
golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -701,12 +698,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg=
golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -714,8 +711,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k=
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -101,6 +101,8 @@ func SetupPprof() {
// SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded.
func SetupStdLogging() {
levelLogAddedMu.Lock()
defer levelLogAddedMu.Unlock()
logrus.SetReportCaller(true)
logrus.SetFormatter(&utcFormatter{
&logrus.TextFormatter{

View File

@ -32,6 +32,8 @@ import (
// If something fails here it means that the logging was improperly configured,
// so we just exit with the error
func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
levelLogAddedMu.Lock()
defer levelLogAddedMu.Unlock()
for _, hook := range hooks {
// Check we received a proper logging level
level, err := logrus.ParseLevel(hook.Level)
@ -85,8 +87,6 @@ func checkSyslogHookParams(params map[string]interface{}) {
}
func setupStdLogHook(level logrus.Level) {
levelLogAddedMu.Lock()
defer levelLogAddedMu.Unlock()
if stdLevelLogAdded[level] {
return
}

View File

@ -0,0 +1,356 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/federationapi/types"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
syncTypes "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
var (
PDUCountTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "recv_pdus",
Help: "Number of incoming PDUs from remote servers with labels for success",
},
[]string{"status"}, // 'success' or 'total'
)
EDUCountTotal = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "recv_edus",
Help: "Number of incoming EDUs from remote servers",
},
)
)
type TxnReq struct {
gomatrixserverlib.Transaction
rsAPI api.FederationRoomserverAPI
keyAPI keyapi.FederationKeyAPI
ourServerName gomatrixserverlib.ServerName
keys gomatrixserverlib.JSONVerifier
roomsMu *MutexByRoom
producer *producers.SyncAPIProducer
inboundPresenceEnabled bool
}
func NewTxnReq(
rsAPI api.FederationRoomserverAPI,
keyAPI keyapi.FederationKeyAPI,
ourServerName gomatrixserverlib.ServerName,
keys gomatrixserverlib.JSONVerifier,
roomsMu *MutexByRoom,
producer *producers.SyncAPIProducer,
inboundPresenceEnabled bool,
pdus []json.RawMessage,
edus []gomatrixserverlib.EDU,
origin gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
destination gomatrixserverlib.ServerName,
) TxnReq {
t := TxnReq{
rsAPI: rsAPI,
keyAPI: keyAPI,
ourServerName: ourServerName,
keys: keys,
roomsMu: roomsMu,
producer: producer,
inboundPresenceEnabled: inboundPresenceEnabled,
}
t.PDUs = pdus
t.EDUs = edus
t.Origin = origin
t.TransactionID = transactionID
t.Destination = destination
return t
}
func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if t.producer != nil {
t.processEDUs(ctx)
}
}()
results := make(map[string]gomatrixserverlib.PDUResult)
roomVersions := make(map[string]gomatrixserverlib.RoomVersion)
getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion {
if v, ok := roomVersions[roomID]; ok {
return v
}
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID)
return ""
}
roomVersions[roomID] = verRes.RoomVersion
return verRes.RoomVersion
}
for _, pdu := range t.PDUs {
PDUCountTotal.WithLabelValues("total").Inc()
var header struct {
RoomID string `json:"room_id"`
}
if err := json.Unmarshal(pdu, &header); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event")
// We don't know the event ID at this point so we can't return the
// failure in the PDU results
continue
}
roomVersion := getRoomVersion(header.RoomID)
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if err != nil {
if _, ok := err.(gomatrixserverlib.BadJSONError); ok {
// Room version 6 states that homeservers should strictly enforce canonical JSON
// on PDUs.
//
// This enforces that the entire transaction is rejected if a single bad PDU is
// sent. It is unclear if this is the correct behaviour or not.
//
// See https://github.com/matrix-org/synapse/issues/7543
return nil, &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("PDU contains bad JSON"),
}
}
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu))
continue
}
if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") {
continue
}
if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) {
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: "Forbidden by server ACLs",
}
continue
}
if err = event.VerifyEventSignatures(ctx, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue
}
// pass the event to the roomserver which will do auth checks
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
// discarded by the caller of this function
if err = api.SendEvents(
ctx,
t.rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion),
},
t.Destination,
t.Origin,
api.DoNotSendToOtherServers,
nil,
true,
); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err)
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue
}
results[event.EventID()] = gomatrixserverlib.PDUResult{}
PDUCountTotal.WithLabelValues("success").Inc()
}
wg.Wait()
return &gomatrixserverlib.RespSend{PDUs: results}, nil
}
// nolint:gocyclo
func (t *TxnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs {
EDUCountTotal.Inc()
switch e.Type {
case gomatrixserverlib.MTyping:
// https://matrix.org/docs/spec/server_server/latest#typing-notifications
var typingPayload struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
Typing bool `json:"typing"`
}
if err := json.Unmarshal(e.Content, &typingPayload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event")
continue
}
if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream")
}
case gomatrixserverlib.MDirectToDevice:
// https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
var directPayload gomatrixserverlib.ToDeviceMessage
if err := json.Unmarshal(e.Content, &directPayload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events")
continue
}
if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
for userID, byUser := range directPayload.Messages {
for deviceID, message := range byUser {
// TODO: check that the user and the device actually exist here
if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
sentry.CaptureException(err)
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
"sender": directPayload.Sender,
"user_id": userID,
"device_id": deviceID,
}).Error("Failed to send send-to-device event to JetStream")
}
}
}
case gomatrixserverlib.MDeviceListUpdate:
if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil {
sentry.CaptureException(err)
util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate")
}
case gomatrixserverlib.MReceipt:
// https://matrix.org/docs/spec/server_server/r0.1.4#receipts
payload := map[string]types.FederationReceiptMRead{}
if err := json.Unmarshal(e.Content, &payload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event")
continue
}
for roomID, receipt := range payload {
for userID, mread := range receipt.User {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender")
continue
}
if t.Origin != domain {
util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin)
continue
}
if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil {
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
"sender": t.Origin,
"user_id": userID,
"room_id": roomID,
"events": mread.EventIDs,
}).Error("Failed to send receipt event to JetStream")
continue
}
}
}
case types.MSigningKeyUpdate:
if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil {
sentry.CaptureException(err)
logrus.WithError(err).Errorf("Failed to process signing key update")
}
case gomatrixserverlib.MPresence:
if t.inboundPresenceEnabled {
if err := t.processPresence(ctx, e); err != nil {
logrus.WithError(err).Errorf("Failed to process presence update")
}
}
default:
util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU")
}
}
}
// processPresence handles m.receipt events
func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error {
payload := types.Presence{}
if err := json.Unmarshal(e.Content, &payload); err != nil {
return err
}
for _, content := range payload.Push {
if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
presence, ok := syncTypes.PresenceFromString(content.Presence)
if !ok {
continue
}
if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil {
return err
}
}
return nil
}
// processReceiptEvent sends receipt events to JetStream
func (t *TxnReq) processReceiptEvent(ctx context.Context,
userID, roomID, receiptType string,
timestamp gomatrixserverlib.Timestamp,
eventIDs []string,
) error {
if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil {
return nil
} else if serverName == t.ourServerName {
return nil
} else if serverName != t.Origin {
return nil
}
// store every event
for _, eventID := range eventIDs {
if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil {
return fmt.Errorf("unable to set receipt event: %w", err)
}
}
return nil
}

View File

@ -0,0 +1,820 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"encoding/json"
"fmt"
"strconv"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/producers"
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"gotest.tools/v3/poll"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
var (
invalidSignatures = json.RawMessage(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localishhost","sender":"@userid:localhost","signatures":{"localhost":{"ed2559:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiaQiWAQ"}},"type":"m.room.member"}`)
testData = []json.RawMessage{
[]byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`),
// messages
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`),
}
testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`)
testRoomVersion = gomatrixserverlib.RoomVersionV1
testEvents = []*gomatrixserverlib.HeaderedEvent{}
testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
)
type FakeRsAPI struct {
rsAPI.RoomserverInternalAPI
shouldFailQuery bool
bannedFromRoom bool
shouldEventsFail bool
}
func (r *FakeRsAPI) QueryRoomVersionForRoom(
ctx context.Context,
req *rsAPI.QueryRoomVersionForRoomRequest,
res *rsAPI.QueryRoomVersionForRoomResponse,
) error {
if r.shouldFailQuery {
return fmt.Errorf("Failure")
}
res.RoomVersion = gomatrixserverlib.RoomVersionV10
return nil
}
func (r *FakeRsAPI) QueryServerBannedFromRoom(
ctx context.Context,
req *rsAPI.QueryServerBannedFromRoomRequest,
res *rsAPI.QueryServerBannedFromRoomResponse,
) error {
if r.bannedFromRoom {
res.Banned = true
} else {
res.Banned = false
}
return nil
}
func (r *FakeRsAPI) InputRoomEvents(
ctx context.Context,
req *rsAPI.InputRoomEventsRequest,
res *rsAPI.InputRoomEventsResponse,
) error {
if r.shouldEventsFail {
return fmt.Errorf("Failure")
}
return nil
}
func TestEmptyTransactionRequest(t *testing.T) {
txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
}
func TestProcessTransactionRequestPDU(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Equal(t, 1, len(txnRes.PDUs))
for _, result := range txnRes.PDUs {
assert.Empty(t, result.Error)
}
}
func TestProcessTransactionRequestPDUs(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, append(testData, testEvent), []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Equal(t, 1, len(txnRes.PDUs))
for _, result := range txnRes.PDUs {
assert.Empty(t, result.Error)
}
}
func TestProcessTransactionRequestBadPDU(t *testing.T) {
pdu := json.RawMessage("{\"room_id\":\"asdf\"}")
pdu2 := json.RawMessage("\"roomid\":\"asdf\"")
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{pdu, pdu2, testEvent}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Equal(t, 1, len(txnRes.PDUs))
for _, result := range txnRes.PDUs {
assert.Empty(t, result.Error)
}
}
func TestProcessTransactionRequestPDUQueryFailure(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{shouldFailQuery: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
}
func TestProcessTransactionRequestPDUBannedFromRoom(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{bannedFromRoom: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Equal(t, 1, len(txnRes.PDUs))
for _, result := range txnRes.PDUs {
assert.NotEmpty(t, result.Error)
}
}
func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{invalidSignatures}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Equal(t, 1, len(txnRes.PDUs))
for _, result := range txnRes.PDUs {
assert.NotEmpty(t, result.Error)
}
}
func TestProcessTransactionRequestPDUSendFail(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "")
txnRes, jsonRes := txn.ProcessTransaction(context.Background())
assert.Nil(t, jsonRes)
assert.Equal(t, 1, len(txnRes.PDUs))
for _, result := range txnRes.PDUs {
assert.NotEmpty(t, result.Error)
}
}
func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) {
cfg := &config.Dendrite{}
cfg.Defaults(config.DefaultOpts{
Generate: true,
Monolithic: true,
})
cfg.Global.JetStream.InMemory = true
natsInstance := &jetstream.NATSInstance{}
js, _ := natsInstance.Prepare(ctx, &cfg.Global.JetStream)
producer := &producers.SyncAPIProducer{
JetStream: js,
TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
Config: &cfg.FederationAPI,
UserAPI: nil,
}
keyRing := &test.NopJSONVerifier{}
txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, edus, "kaer.morhen", "", "ourserver")
return txn, js, cfg
}
func TestProcessTransactionRequestEDUTyping(t *testing.T) {
var err error
roomID := "!roomid:kaer.morhen"
userID := "@userid:kaer.morhen"
typing := true
edu := gomatrixserverlib.EDU{Type: "m.typing"}
if edu.Content, err = json.Marshal(map[string]interface{}{
"room_id": roomID,
"user_id": userID,
"typing": typing,
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badEDU := gomatrixserverlib.EDU{Type: "m.typing"}
badEDU.Content = gomatrixserverlib.RawJSON("badjson")
edus := []gomatrixserverlib.EDU{badEDU, edu}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, js, cfg := createTransactionWithEDU(ctx, edus)
received := atomic.NewBool(false)
onMessage := func(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
room := msg.Header.Get(jetstream.RoomID)
assert.Equal(t, roomID, room)
user := msg.Header.Get(jetstream.UserID)
assert.Equal(t, userID, user)
typ, parseErr := strconv.ParseBool(msg.Header.Get("typing"))
if parseErr != nil {
return true
}
assert.Equal(t, typing, typ)
received.Store(true)
return true
}
err = jetstream.JetStreamConsumer(
ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent),
cfg.Global.JetStream.Durable("TestTypingConsumer"), 1,
onMessage, nats.DeliverAll(), nats.ManualAck(),
)
assert.Nil(t, err)
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
check := func(log poll.LogT) poll.Result {
if received.Load() {
return poll.Success()
}
return poll.Continue("waiting for events to be processed")
}
poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond))
}
func TestProcessTransactionRequestEDUToDevice(t *testing.T) {
var err error
sender := "@userid:kaer.morhen"
messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg"
msgType := "m.dendrite.test"
edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"}
if edu.Content, err = json.Marshal(map[string]interface{}{
"sender": sender,
"type": msgType,
"message_id": messageID,
"messages": map[string]interface{}{
"@alice:example.org": map[string]interface{}{
"IWHQUZUIAH": map[string]interface{}{
"algorithm": "m.megolm.v1.aes-sha2",
"room_id": "!Cuyf34gef24t:localhost",
"session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
"session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...",
},
},
},
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"}
badEDU.Content = gomatrixserverlib.RawJSON("badjson")
edus := []gomatrixserverlib.EDU{badEDU, edu}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, js, cfg := createTransactionWithEDU(ctx, edus)
received := atomic.NewBool(false)
onMessage := func(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.OutputSendToDeviceEvent
if err = json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
println(err.Error())
return true
}
assert.Equal(t, sender, output.Sender)
assert.Equal(t, msgType, output.Type)
received.Store(true)
return true
}
err = jetstream.JetStreamConsumer(
ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
cfg.Global.JetStream.Durable("TestToDevice"), 1,
onMessage, nats.DeliverAll(), nats.ManualAck(),
)
assert.Nil(t, err)
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
check := func(log poll.LogT) poll.Result {
if received.Load() {
return poll.Success()
}
return poll.Continue("waiting for events to be processed")
}
poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond))
}
func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) {
var err error
deviceID := "QBUAZIFURK"
userID := "@john:example.com"
edu := gomatrixserverlib.EDU{Type: "m.device_list_update"}
if edu.Content, err = json.Marshal(map[string]interface{}{
"device_display_name": "Mobile",
"device_id": deviceID,
"key": "value",
"keys": map[string]interface{}{
"algorithms": []string{
"m.olm.v1.curve25519-aes-sha2",
"m.megolm.v1.aes-sha2",
},
"device_id": "JLAFKJWSCS",
"keys": map[string]interface{}{
"curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI",
"ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI",
},
"signatures": map[string]interface{}{
"@alice:example.com": map[string]interface{}{
"ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA",
},
},
"user_id": "@alice:example.com",
},
"prev_id": []int{
5,
},
"stream_id": 6,
"user_id": userID,
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"}
badEDU.Content = gomatrixserverlib.RawJSON("badjson")
edus := []gomatrixserverlib.EDU{badEDU, edu}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, js, cfg := createTransactionWithEDU(ctx, edus)
received := atomic.NewBool(false)
onMessage := func(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output gomatrixserverlib.DeviceListUpdateEvent
if err = json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
println(err.Error())
return true
}
assert.Equal(t, userID, output.UserID)
assert.Equal(t, deviceID, output.DeviceID)
received.Store(true)
return true
}
err = jetstream.JetStreamConsumer(
ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1,
onMessage, nats.DeliverAll(), nats.ManualAck(),
)
assert.Nil(t, err)
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
check := func(log poll.LogT) poll.Result {
if received.Load() {
return poll.Success()
}
return poll.Continue("waiting for events to be processed")
}
poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond))
}
func TestProcessTransactionRequestEDUReceipt(t *testing.T) {
var err error
roomID := "!some_room:example.org"
edu := gomatrixserverlib.EDU{Type: "m.receipt"}
if edu.Content, err = json.Marshal(map[string]interface{}{
roomID: map[string]interface{}{
"m.read": map[string]interface{}{
"@john:kaer.morhen": map[string]interface{}{
"data": map[string]interface{}{
"ts": 1533358089009,
},
"event_ids": []string{
"$read_this_event:matrix.org",
},
},
},
},
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badEDU := gomatrixserverlib.EDU{Type: "m.receipt"}
badEDU.Content = gomatrixserverlib.RawJSON("badjson")
badUser := gomatrixserverlib.EDU{Type: "m.receipt"}
if badUser.Content, err = json.Marshal(map[string]interface{}{
roomID: map[string]interface{}{
"m.read": map[string]interface{}{
"johnkaer.morhen": map[string]interface{}{
"data": map[string]interface{}{
"ts": 1533358089009,
},
"event_ids": []string{
"$read_this_event:matrix.org",
},
},
},
},
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badDomain := gomatrixserverlib.EDU{Type: "m.receipt"}
if badDomain.Content, err = json.Marshal(map[string]interface{}{
roomID: map[string]interface{}{
"m.read": map[string]interface{}{
"@john:bad.domain": map[string]interface{}{
"data": map[string]interface{}{
"ts": 1533358089009,
},
"event_ids": []string{
"$read_this_event:matrix.org",
},
},
},
},
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
edus := []gomatrixserverlib.EDU{badEDU, badUser, edu}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, js, cfg := createTransactionWithEDU(ctx, edus)
received := atomic.NewBool(false)
onMessage := func(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.OutputReceiptEvent
output.RoomID = msg.Header.Get(jetstream.RoomID)
assert.Equal(t, roomID, output.RoomID)
received.Store(true)
return true
}
err = jetstream.JetStreamConsumer(
ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent),
cfg.Global.JetStream.Durable("TestReceipt"), 1,
onMessage, nats.DeliverAll(), nats.ManualAck(),
)
assert.Nil(t, err)
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
check := func(log poll.LogT) poll.Result {
if received.Load() {
return poll.Success()
}
return poll.Continue("waiting for events to be processed")
}
poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond))
}
func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) {
var err error
edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"}
if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"}
badEDU.Content = gomatrixserverlib.RawJSON("badjson")
edus := []gomatrixserverlib.EDU{badEDU, edu}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, js, cfg := createTransactionWithEDU(ctx, edus)
received := atomic.NewBool(false)
onMessage := func(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output keyAPI.CrossSigningKeyUpdate
if err = json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
println(err.Error())
return true
}
received.Store(true)
return true
}
err = jetstream.JetStreamConsumer(
ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1,
onMessage, nats.DeliverAll(), nats.ManualAck(),
)
assert.Nil(t, err)
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
check := func(log poll.LogT) poll.Result {
if received.Load() {
return poll.Success()
}
return poll.Continue("waiting for events to be processed")
}
poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond))
}
func TestProcessTransactionRequestEDUPresence(t *testing.T) {
var err error
userID := "@john:kaer.morhen"
presence := "online"
edu := gomatrixserverlib.EDU{Type: "m.presence"}
if edu.Content, err = json.Marshal(map[string]interface{}{
"push": []map[string]interface{}{{
"currently_active": true,
"last_active_ago": 5000,
"presence": presence,
"status_msg": "Making cupcakes",
"user_id": userID,
}},
}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
badEDU := gomatrixserverlib.EDU{Type: "m.presence"}
badEDU.Content = gomatrixserverlib.RawJSON("badjson")
edus := []gomatrixserverlib.EDU{badEDU, edu}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, js, cfg := createTransactionWithEDU(ctx, edus)
received := atomic.NewBool(false)
onMessage := func(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userIDRes := msg.Header.Get(jetstream.UserID)
presenceRes := msg.Header.Get("presence")
assert.Equal(t, userID, userIDRes)
assert.Equal(t, presence, presenceRes)
received.Store(true)
return true
}
err = jetstream.JetStreamConsumer(
ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent),
cfg.Global.JetStream.Durable("TestPresence"), 1,
onMessage, nats.DeliverAll(), nats.ManualAck(),
)
assert.Nil(t, err)
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
check := func(log poll.LogT) poll.Result {
if received.Load() {
return poll.Success()
}
return poll.Continue("waiting for events to be processed")
}
poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond))
}
func TestProcessTransactionRequestEDUUnhandled(t *testing.T) {
var err error
edu := gomatrixserverlib.EDU{Type: "m.unhandled"}
if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil {
t.Errorf("failed to marshal EDU JSON")
}
ctx := process.NewProcessContext()
defer ctx.ShutdownDendrite()
txn, _, _ := createTransactionWithEDU(ctx, []gomatrixserverlib.EDU{edu})
txnRes, jsonRes := txn.ProcessTransaction(ctx.Context())
assert.Nil(t, jsonRes)
assert.Zero(t, len(txnRes.PDUs))
}
func init() {
for _, j := range testData {
e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion)
if err != nil {
panic("cannot load test data: " + err.Error())
}
h := e.Headered(testRoomVersion)
testEvents = append(testEvents, h)
if e.StateKey() != nil {
testStateEvents[gomatrixserverlib.StateKeyTuple{
EventType: e.Type(),
StateKey: *e.StateKey(),
}] = h
}
}
}
type testRoomserverAPI struct {
rsAPI.RoomserverInternalAPITrace
inputRoomEvents []rsAPI.InputRoomEvent
queryStateAfterEvents func(*rsAPI.QueryStateAfterEventsRequest) rsAPI.QueryStateAfterEventsResponse
queryEventsByID func(req *rsAPI.QueryEventsByIDRequest) rsAPI.QueryEventsByIDResponse
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
}
func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *rsAPI.InputRoomEventsRequest,
response *rsAPI.InputRoomEventsResponse,
) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID())
}
return nil
}
// Query the latest events and state for a room from the room server.
func (t *testRoomserverAPI) QueryLatestEventsAndState(
ctx context.Context,
request *rsAPI.QueryLatestEventsAndStateRequest,
response *rsAPI.QueryLatestEventsAndStateResponse,
) error {
r := t.queryLatestEventsAndState(request)
response.RoomExists = r.RoomExists
response.RoomVersion = testRoomVersion
response.LatestEvents = r.LatestEvents
response.StateEvents = r.StateEvents
response.Depth = r.Depth
return nil
}
// Query the state after a list of events in a room from the room server.
func (t *testRoomserverAPI) QueryStateAfterEvents(
ctx context.Context,
request *rsAPI.QueryStateAfterEventsRequest,
response *rsAPI.QueryStateAfterEventsResponse,
) error {
response.RoomVersion = testRoomVersion
res := t.queryStateAfterEvents(request)
response.PrevEventsExist = res.PrevEventsExist
response.RoomExists = res.RoomExists
response.StateEvents = res.StateEvents
return nil
}
// Query a list of events by event ID.
func (t *testRoomserverAPI) QueryEventsByID(
ctx context.Context,
request *rsAPI.QueryEventsByIDRequest,
response *rsAPI.QueryEventsByIDResponse,
) error {
res := t.queryEventsByID(request)
response.Events = res.Events
return nil
}
// Query if a server is joined to a room
func (t *testRoomserverAPI) QueryServerJoinedToRoom(
ctx context.Context,
request *rsAPI.QueryServerJoinedToRoomRequest,
response *rsAPI.QueryServerJoinedToRoomResponse,
) error {
response.RoomExists = true
response.IsInRoom = true
return nil
}
// Asks for the room version for a given room.
func (t *testRoomserverAPI) QueryRoomVersionForRoom(
ctx context.Context,
request *rsAPI.QueryRoomVersionForRoomRequest,
response *rsAPI.QueryRoomVersionForRoomResponse,
) error {
response.RoomVersion = testRoomVersion
return nil
}
func (t *testRoomserverAPI) QueryServerBannedFromRoom(
ctx context.Context, req *rsAPI.QueryServerBannedFromRoomRequest, res *rsAPI.QueryServerBannedFromRoomResponse,
) error {
res.Banned = false
return nil
}
func mustCreateTransaction(rsAPI rsAPI.FederationRoomserverAPI, pdus []json.RawMessage) *TxnReq {
t := NewTxnReq(
rsAPI,
nil,
"",
&test.NopJSONVerifier{},
NewMutexByRoom(),
nil,
false,
pdus,
nil,
testOrigin,
gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())),
testDestination)
t.PDUs = pdus
t.Origin = testOrigin
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
t.Destination = testDestination
return &t
}
func mustProcessTransaction(t *testing.T, txn *TxnReq, pdusWithErrors []string) {
res, err := txn.ProcessTransaction(context.Background())
if err != nil {
t.Errorf("txn.processTransaction returned an error: %v", err)
return
}
if len(res.PDUs) != len(txn.PDUs) {
t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs))
return
}
NextPDU:
for eventID, result := range res.PDUs {
if result.Error == "" {
continue
}
for _, eventIDWantError := range pdusWithErrors {
if eventID == eventIDWantError {
break NextPDU
}
}
t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error)
}
}
func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) {
for _, g := range got {
fmt.Println("GOT ", g.Event.EventID())
}
if len(got) != len(want) {
t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want))
return
}
for i := range got {
if got[i].Event.EventID() != want[i].EventID() {
t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID())
}
}
}
// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on
// to the roomserver. It's the most basic test possible.
func TestBasicTransaction(t *testing.T) {
rsAPI := &testRoomserverAPI{}
pdus := []json.RawMessage{
testData[len(testData)-1], // a message event
}
txn := mustCreateTransaction(rsAPI, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
}
// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
// as it does the auth check.
func TestTransactionFailAuthChecks(t *testing.T) {
rsAPI := &testRoomserverAPI{}
pdus := []json.RawMessage{
testData[len(testData)-1], // a message event
}
txn := mustCreateTransaction(rsAPI, pdus)
mustProcessTransaction(t, txn, []string{})
// expect message to be sent to the roomserver
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
}

View File

@ -108,13 +108,16 @@ func makeDownloadAPI(
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) http.HandlerFunc {
counterVec := promauto.NewCounterVec(
var counterVec *prometheus.CounterVec
if cfg.Matrix.Metrics.Enabled {
counterVec = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: name,
Help: "Total number of media_api requests for either thumbnails or full downloads",
},
[]string{"code"},
)
}
httpHandler := func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)
@ -166,5 +169,12 @@ func makeDownloadAPI(
vars["downloadName"],
)
}
return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler))
var handlerFunc http.HandlerFunc
if counterVec != nil {
handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler))
} else {
handlerFunc = http.HandlerFunc(httpHandler)
}
return handlerFunc
}

56
relayapi/api/api.go Normal file
View File

@ -0,0 +1,56 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
)
// RelayInternalAPI is used to query information from the relay server.
type RelayInternalAPI interface {
RelayServerAPI
// Retrieve from external relay server all transactions stored for us and process them.
PerformRelayServerSync(
ctx context.Context,
userID gomatrixserverlib.UserID,
relayServer gomatrixserverlib.ServerName,
) error
}
// RelayServerAPI exposes the store & query transaction functionality of a relay server.
type RelayServerAPI interface {
// Store transactions for forwarding to the destination at a later time.
PerformStoreTransaction(
ctx context.Context,
transaction gomatrixserverlib.Transaction,
userID gomatrixserverlib.UserID,
) error
// Obtain the oldest stored transaction for the specified userID.
QueryTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
previousEntry gomatrixserverlib.RelayEntry,
) (QueryRelayTransactionsResponse, error)
}
type QueryRelayTransactionsResponse struct {
Transaction gomatrixserverlib.Transaction `json:"transaction"`
EntryID int64 `json:"entry_id"`
EntriesQueued bool `json:"entries_queued"`
}

53
relayapi/internal/api.go Normal file
View File

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
fedAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
type RelayInternalAPI struct {
db storage.Database
fedClient fedAPI.FederationClient
rsAPI rsAPI.RoomserverInternalAPI
keyRing *gomatrixserverlib.KeyRing
producer *producers.SyncAPIProducer
presenceEnabledInbound bool
serverName gomatrixserverlib.ServerName
}
func NewRelayInternalAPI(
db storage.Database,
fedClient fedAPI.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
presenceEnabledInbound bool,
serverName gomatrixserverlib.ServerName,
) *RelayInternalAPI {
return &RelayInternalAPI{
db: db,
fedClient: fedClient,
rsAPI: rsAPI,
keyRing: keyRing,
producer: producer,
presenceEnabledInbound: presenceEnabledInbound,
serverName: serverName,
}
}

View File

@ -0,0 +1,141 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// PerformRelayServerSync implements api.RelayInternalAPI
func (r *RelayInternalAPI) PerformRelayServerSync(
ctx context.Context,
userID gomatrixserverlib.UserID,
relayServer gomatrixserverlib.ServerName,
) error {
// Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any
// transactions available for this node.
prevEntry := gomatrixserverlib.RelayEntry{}
asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer)
if err != nil {
logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Txn)
prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID}
for asyncResponse.EntriesQueued {
// There are still more entries available for this node from the relay.
logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry)
asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer)
prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID}
if err != nil {
logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error())
return err
}
r.processTransaction(&asyncResponse.Txn)
}
return nil
}
// PerformStoreTransaction implements api.RelayInternalAPI
func (r *RelayInternalAPI) PerformStoreTransaction(
ctx context.Context,
transaction gomatrixserverlib.Transaction,
userID gomatrixserverlib.UserID,
) error {
logrus.Warnf("Storing transaction for %v", userID)
receipt, err := r.db.StoreTransaction(ctx, transaction)
if err != nil {
logrus.Errorf("db.StoreTransaction: %s", err.Error())
return err
}
err = r.db.AssociateTransactionWithDestinations(
ctx,
map[gomatrixserverlib.UserID]struct{}{
userID: {},
},
transaction.TransactionID,
receipt)
return err
}
// QueryTransactions implements api.RelayInternalAPI
func (r *RelayInternalAPI) QueryTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
previousEntry gomatrixserverlib.RelayEntry,
) (api.QueryRelayTransactionsResponse, error) {
logrus.Infof("QueryTransactions for %s", userID.Raw())
if previousEntry.EntryID > 0 {
logrus.Infof("Cleaning previous entry (%v) from db for %s",
previousEntry.EntryID,
userID.Raw(),
)
prevReceipt := receipt.NewReceipt(previousEntry.EntryID)
err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt})
if err != nil {
logrus.Errorf("db.CleanTransactions: %s", err.Error())
return api.QueryRelayTransactionsResponse{}, err
}
}
transaction, receipt, err := r.db.GetTransaction(ctx, userID)
if err != nil {
logrus.Errorf("db.GetTransaction: %s", err.Error())
return api.QueryRelayTransactionsResponse{}, err
}
response := api.QueryRelayTransactionsResponse{}
if transaction != nil && receipt != nil {
logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw())
response.Transaction = *transaction
response.EntryID = receipt.GetNID()
response.EntriesQueued = true
} else {
logrus.Infof("No more entries in the queue for %s", userID.Raw())
response.EntryID = 0
response.EntriesQueued = false
}
return response, nil
}
func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom()
t := internal.NewTxnReq(
r.rsAPI,
nil,
r.serverName,
r.keyRing,
mu,
r.producer,
r.presenceEnabledInbound,
txn.PDUs,
txn.EDUs,
txn.Origin,
txn.TransactionID,
txn.Destination)
t.ProcessTransaction(context.TODO())
}

View File

@ -0,0 +1,121 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"testing"
fedAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type testFedClient struct {
fedAPI.FederationClient
shouldFail bool
queryCount uint
queueDepth uint
}
func (f *testFedClient) P2PGetTransactionFromRelay(
ctx context.Context,
u gomatrixserverlib.UserID,
prev gomatrixserverlib.RelayEntry,
relayServer gomatrixserverlib.ServerName,
) (res gomatrixserverlib.RespGetRelayTransaction, err error) {
f.queryCount++
if f.shouldFail {
return res, fmt.Errorf("Error")
}
res = gomatrixserverlib.RespGetRelayTransaction{
Txn: gomatrixserverlib.Transaction{},
EntryID: 0,
}
if f.queueDepth > 0 {
res.EntriesQueued = true
} else {
res.EntriesQueued = false
}
f.queueDepth -= 1
return
}
func TestPerformRelayServerSync(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.Nil(t, err, "Invalid userID")
fedClient := &testFedClient{}
relayAPI := NewRelayInternalAPI(
&db, fedClient, nil, nil, nil, false, "",
)
err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay"))
assert.NoError(t, err)
}
func TestPerformRelayServerSyncFedError(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.Nil(t, err, "Invalid userID")
fedClient := &testFedClient{shouldFail: true}
relayAPI := NewRelayInternalAPI(
&db, fedClient, nil, nil, nil, false, "",
)
err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay"))
assert.Error(t, err)
}
func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.Nil(t, err, "Invalid userID")
fedClient := &testFedClient{queueDepth: 2}
relayAPI := NewRelayInternalAPI(
&db, fedClient, nil, nil, nil, false, "",
)
err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay"))
assert.NoError(t, err)
assert.Equal(t, uint(3), fedClient.queryCount)
}

74
relayapi/relayapi.go Normal file
View File

@ -0,0 +1,74 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package relayapi
import (
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
base *base.BaseDendrite,
keyRing gomatrixserverlib.JSONVerifier,
relayAPI api.RelayInternalAPI,
) {
fedCfg := &base.Cfg.FederationAPI
relay, ok := relayAPI.(*internal.RelayInternalAPI)
if !ok {
panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " +
"RelayInternalAPI. This is a programming error.")
}
routing.Setup(
base.PublicFederationAPIMux,
fedCfg,
relay,
keyRing,
)
}
func NewRelayInternalAPI(
base *base.BaseDendrite,
fedClient *gomatrixserverlib.FederationClient,
rsAPI rsAPI.RoomserverInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
producer *producers.SyncAPIProducer,
) api.RelayInternalAPI {
cfg := &base.Cfg.RelayAPI
relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
if err != nil {
logrus.WithError(err).Panic("failed to connect to relay db")
}
return internal.NewRelayInternalAPI(
relayDB,
fedClient,
rsAPI,
keyRing,
producer,
base.Cfg.Global.Presence.EnableInbound,
base.Cfg.Global.ServerName,
)
}

154
relayapi/relayapi_test.go Normal file
View File

@ -0,0 +1,154 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package relayapi_test
import (
"crypto/ed25519"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/relayapi"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
func TestCreateNewRelayInternalAPI(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil)
assert.NotNil(t, relayAPI)
})
}
func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
if dbType == test.DBTypeSQLite {
base.Cfg.RelayAPI.Database.ConnectionString = "file:"
} else {
base.Cfg.RelayAPI.Database.ConnectionString = "test"
}
defer close()
assert.Panics(t, func() {
relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil)
})
})
}
func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
assert.Panics(t, func() {
relayapi.AddPublicRoutes(base, nil, nil)
})
})
}
func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request {
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID)
content := gomatrixserverlib.RelayEntry{EntryID: 0}
req.SetContent(content)
req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk)
httpreq, _ := req.HTTPRequest()
vars := map[string]string{"userID": userID}
httpreq = mux.SetURLVars(httpreq, vars)
return httpreq
}
type sendRelayContent struct {
PDUs []json.RawMessage `json:"pdus"`
EDUs []gomatrixserverlib.EDU `json:"edus"`
}
func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request {
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID)
content := sendRelayContent{}
req.SetContent(content)
req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk)
httpreq, _ := req.HTTPRequest()
vars := map[string]string{"userID": userID, "txnID": txnID}
httpreq = mux.SetURLVars(httpreq, vars)
return httpreq
}
func TestCreateRelayPublicRoutes(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil)
assert.NotNil(t, relayAPI)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
relayapi.AddPublicRoutes(base, keyRing, relayAPI)
testCases := []struct {
name string
req *http.Request
wantCode int
wantJoinedRooms []string
}{
{
name: "relay_txn invalid user id",
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"),
wantCode: 400,
},
{
name: "relay_txn valid user id",
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"),
wantCode: 200,
},
{
name: "send_relay invalid user id",
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"),
wantCode: 400,
},
{
name: "send_relay valid user id",
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"),
wantCode: 200,
},
}
for _, tc := range testCases {
w := httptest.NewRecorder()
base.PublicFederationAPIMux.ServeHTTP(w, tc.req)
if w.Code != tc.wantCode {
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
}
}
})
}

View File

@ -0,0 +1,74 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
type RelayTransactionResponse struct {
Transaction gomatrixserverlib.Transaction `json:"transaction"`
EntryID int64 `json:"entry_id,omitempty"`
EntriesQueued bool `json:"entries_queued"`
}
// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID}
// This endpoint can be extracted into a separate relay server service.
func GetTransactionFromRelay(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
relayAPI api.RelayInternalAPI,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
logrus.Infof("Handling relay_txn for %s", userID.Raw())
previousEntry := gomatrixserverlib.RelayEntry{}
if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("invalid json provided"),
}
}
if previousEntry.EntryID < 0 {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."),
}
}
logrus.Infof("Previous entry provided: %v", previousEntry.EntryID)
response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: RelayTransactionResponse{
Transaction: response.Transaction,
EntryID: response.EntryID,
EntriesQueued: response.EntriesQueued,
},
}
}

View File

@ -0,0 +1,220 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"net/http"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
func createQuery(
userID gomatrixserverlib.UserID,
prevEntry gomatrixserverlib.RelayEntry,
) gomatrixserverlib.FederationRequest {
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path)
request.SetContent(prevEntry)
return request
}
func TestGetEmptyDatabaseReturnsNothing(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
transaction := createTransaction()
_, err = db.StoreTransaction(context.Background(), transaction)
assert.NoError(t, err, "Failed to store transaction")
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
request := createQuery(*userID, gomatrixserverlib.RelayEntry{})
response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.RelayTransactionResponse)
assert.Equal(t, false, jsonResponse.EntriesQueued)
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
count, err := db.GetTransactionCount(context.Background(), *userID)
assert.NoError(t, err)
assert.Zero(t, count)
}
func TestGetInvalidPrevEntryFails(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
transaction := createTransaction()
_, err = db.StoreTransaction(context.Background(), transaction)
assert.NoError(t, err, "Failed to store transaction")
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1})
response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusInternalServerError, response.Code)
}
func TestGetReturnsSavedTransaction(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
transaction := createTransaction()
receipt, err := db.StoreTransaction(context.Background(), transaction)
assert.NoError(t, err, "Failed to store transaction")
err = db.AssociateTransactionWithDestinations(
context.Background(),
map[gomatrixserverlib.UserID]struct{}{
*userID: {},
},
transaction.TransactionID,
receipt)
assert.NoError(t, err, "Failed to associate transaction with user")
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
request := createQuery(*userID, gomatrixserverlib.RelayEntry{})
response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.RelayTransactionResponse)
assert.True(t, jsonResponse.EntriesQueued)
assert.Equal(t, transaction, jsonResponse.Transaction)
// And once more to clear the queue
request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID})
response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.RelayTransactionResponse)
assert.False(t, jsonResponse.EntriesQueued)
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
count, err := db.GetTransactionCount(context.Background(), *userID)
assert.NoError(t, err)
assert.Zero(t, count)
}
func TestGetReturnsMultipleSavedTransactions(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
transaction := createTransaction()
receipt, err := db.StoreTransaction(context.Background(), transaction)
assert.NoError(t, err, "Failed to store transaction")
err = db.AssociateTransactionWithDestinations(
context.Background(),
map[gomatrixserverlib.UserID]struct{}{
*userID: {},
},
transaction.TransactionID,
receipt)
assert.NoError(t, err, "Failed to associate transaction with user")
transaction2 := createTransaction()
receipt2, err := db.StoreTransaction(context.Background(), transaction2)
assert.NoError(t, err, "Failed to store transaction")
err = db.AssociateTransactionWithDestinations(
context.Background(),
map[gomatrixserverlib.UserID]struct{}{
*userID: {},
},
transaction2.TransactionID,
receipt2)
assert.NoError(t, err, "Failed to associate transaction with user")
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
request := createQuery(*userID, gomatrixserverlib.RelayEntry{})
response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse := response.JSON.(routing.RelayTransactionResponse)
assert.True(t, jsonResponse.EntriesQueued)
assert.Equal(t, transaction, jsonResponse.Transaction)
request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID})
response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.RelayTransactionResponse)
assert.True(t, jsonResponse.EntriesQueued)
assert.Equal(t, transaction2, jsonResponse.Transaction)
// And once more to clear the queue
request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID})
response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
assert.Equal(t, http.StatusOK, response.Code)
jsonResponse = response.JSON.(routing.RelayTransactionResponse)
assert.False(t, jsonResponse.EntriesQueued)
assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
count, err := db.GetTransactionCount(context.Background(), *userID)
assert.NoError(t, err)
assert.Zero(t, count)
}

123
relayapi/routing/routing.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"fmt"
"net/http"
"time"
"github.com/getsentry/sentry-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
relayInternal "github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// Setup registers HTTP handlers with the given ServeMux.
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled
// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz]
//
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
func Setup(
fedMux *mux.Router,
cfg *config.FederationAPI,
relayAPI *relayInternal.RelayInternalAPI,
keys gomatrixserverlib.JSONVerifier,
) {
v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI(
"send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return SendTransactionToRelay(
httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
*userID,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI(
"get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username was invalid"),
}
}
return GetTransactionFromRelay(httpReq, request, relayAPI, *userID)
},
)).Methods(http.MethodGet, http.MethodOptions)
}
// MakeRelayAPI makes an http.Handler that checks matrix relay authentication.
func MakeRelayAPI(
metricsName string, serverName gomatrixserverlib.ServerName,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
keyRing gomatrixserverlib.JSONVerifier,
f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), serverName, isLocalServerName, keyRing,
)
if fedReq == nil {
return errResp
}
// add the user to Sentry, if enabled
hub := sentry.GetHubFromContext(req.Context())
if hub != nil {
hub.Scope().SetTag("origin", string(fedReq.Origin()))
hub.Scope().SetTag("uri", fedReq.RequestURI())
}
defer func() {
if r := recover(); r != nil {
if hub != nil {
hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
}
// re-panic to return the 500
panic(r)
}
}()
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params")
}
jsonRes := f(req, fedReq, vars)
// do not log 4xx as errors as they are client fails, not server fails
if hub != nil && jsonRes.Code >= 500 {
hub.Scope().SetExtra("response", jsonRes)
hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
}
return jsonRes
}
return httputil.MakeExternalAPI(metricsName, h)
}

View File

@ -0,0 +1,77 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID}
// This endpoint can be extracted into a separate relay server service.
func SendTransactionToRelay(
httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest,
relayAPI api.RelayInternalAPI,
txnID gomatrixserverlib.TransactionID,
userID gomatrixserverlib.UserID,
) util.JSONResponse {
var txnEvents struct {
PDUs []json.RawMessage `json:"pdus"`
EDUs []gomatrixserverlib.EDU `json:"edus"`
}
if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil {
logrus.Info("The request body could not be decoded into valid JSON." + err.Error())
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()),
}
}
// Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs.
// https://matrix.org/docs/spec/server_server/latest#transactions
if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"),
}
}
t := gomatrixserverlib.Transaction{}
t.PDUs = txnEvents.PDUs
t.EDUs = txnEvents.EDUs
t.Origin = fedReq.Origin()
t.TransactionID = txnID
t.Destination = userID.Domain()
util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs))
err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("could not store the transaction for forwarding"),
}
}
return util.JSONResponse{Code: 200}
}

View File

@ -0,0 +1,209 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/internal"
"github.com/matrix-org/dendrite/relayapi/routing"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
)
func createTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
return txn
}
func createFederationRequest(
userID gomatrixserverlib.UserID,
txnID gomatrixserverlib.TransactionID,
origin gomatrixserverlib.ServerName,
destination gomatrixserverlib.ServerName,
content interface{},
) gomatrixserverlib.FederationRequest {
var federationPathPrefixV1 = "/_matrix/federation/v1"
path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw()
request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path)
request.SetContent(content)
return request
}
func TestForwardEmptyReturnsOk(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
assert.Equal(t, 200, response.Code)
}
func TestForwardBadJSONReturnsError(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
type BadData struct {
Field bool `json:"pdus"`
}
content := BadData{
Field: false,
}
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
assert.NotEqual(t, 200, response.Code)
}
func TestForwardTooManyPDUsReturnsError(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
type BadData struct {
Field []json.RawMessage `json:"pdus"`
}
content := BadData{
Field: []json.RawMessage{},
}
for i := 0; i < 51; i++ {
content.Field = append(content.Field, []byte{})
}
assert.Greater(t, len(content.Field), 50)
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
assert.NotEqual(t, 200, response.Code)
}
func TestForwardTooManyEDUsReturnsError(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
type BadData struct {
Field []gomatrixserverlib.EDU `json:"edus"`
}
content := BadData{
Field: []gomatrixserverlib.EDU{},
}
for i := 0; i < 101; i++ {
content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping})
}
assert.Greater(t, len(content.Field), 100)
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
assert.NotEqual(t, 200, response.Code)
}
func TestUniqueTransactionStoredInDatabase(t *testing.T) {
testDB := test.NewInMemoryRelayDatabase()
db := shared.Database{
Writer: sqlutil.NewDummyWriter(),
RelayQueue: testDB,
RelayQueueJSON: testDB,
}
httpReq := &http.Request{}
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
assert.NoError(t, err, "Invalid userID")
txn := createTransaction()
request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn)
relayAPI := internal.NewRelayInternalAPI(
&db, nil, nil, nil, nil, false, "",
)
response := routing.SendTransactionToRelay(
httpReq, &request, relayAPI, txn.TransactionID, *userID)
transaction, _, err := db.GetTransaction(context.Background(), *userID)
assert.NoError(t, err, "Failed retrieving transaction")
transactionCount, err := db.GetTransactionCount(context.Background(), *userID)
assert.NoError(t, err, "Failed retrieving transaction count")
assert.Equal(t, 200, response.Code)
assert.Equal(t, int64(1), transactionCount)
assert.Equal(t, txn.TransactionID, transaction.TransactionID)
}

View File

@ -0,0 +1,47 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
// Adds a new transaction to the queue json table.
// Adding a duplicate transaction will result in a new row being added and a new unique nid.
// return: unique nid representing this entry.
StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error)
// Adds a new transaction_id: server_name mapping with associated json table nid to the queue
// entry table for each provided destination.
AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error
// Removes every server_name: receipt pair provided from the queue entries table.
// Will then remove every entry for each receipt provided from the queue json table.
// If any of the entries don't exist in either table, nothing will happen for that entry and
// an error will not be generated.
CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error
// Gets the oldest transaction for the provided server_name.
// If no transactions exist, returns nil and no error.
GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error)
// Gets the number of transactions being stored for the provided server_name.
// If the server doesn't exist in the database then 0 is returned with no error.
GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
}

View File

@ -0,0 +1,113 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const relayQueueJSONSchema = `
-- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid BIGSERIAL,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON relayapi_queue_json (json_nid);
`
const insertQueueJSONSQL = "" +
"INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)" +
" RETURNING json_nid"
const deleteQueueJSONSQL = "" +
"DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)"
const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid = ANY($1)"
type relayQueueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
deleteJSONStmt *sql.Stmt
selectJSONStmt *sql.Stmt
}
func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &relayQueueJSONStatements{
db: db,
}
_, err = s.db.Exec(relayQueueJSONSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertJSONStmt, insertQueueJSONSQL},
{&s.deleteJSONStmt, deleteQueueJSONSQL},
{&s.selectJSONStmt, selectQueueJSONSQL},
}.Prepare(db)
}
func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (int64, error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
var lastid int64
if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil {
return 0, err
}
return lastid, nil
}
func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
_, err := stmt.ExecContext(ctx, pq.Int64Array(nids))
return err
}
func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, s.selectJSONStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, err
}
blobs[nid] = blob
}
return blobs, err
}

View File

@ -0,0 +1,156 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to.
server_name TEXT NOT NULL,
-- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON relayapi_queue (server_name);
`
const insertQueueEntrySQL = "" +
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueEntriesSQL = "" +
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" +
" ORDER BY json_nid" +
" LIMIT $2"
const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1"
type relayQueueStatements struct {
db *sql.DB
insertQueueEntryStmt *sql.Stmt
deleteQueueEntriesStmt *sql.Stmt
selectQueueEntriesStmt *sql.Stmt
selectQueueEntryCountStmt *sql.Stmt
}
func NewPostgresRelayQueueTable(
db *sql.DB,
) (s *relayQueueStatements, err error) {
s = &relayQueueStatements{
db: db,
}
_, err = s.db.Exec(relayQueueSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertQueueEntryStmt, insertQueueEntrySQL},
{&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL},
{&s.selectQueueEntriesStmt, selectQueueEntriesSQL},
{&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL},
}.Prepare(db)
}
func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}

View File

@ -0,0 +1,64 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the relayapi
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(
base *base.BaseDendrite,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err
}
queue, err := NewPostgresRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View File

@ -0,0 +1,170 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
type Database struct {
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
RelayQueue tables.RelayQueue
RelayQueueJSON tables.RelayQueueJSON
}
func (d *Database) StoreTransaction(
ctx context.Context,
transaction gomatrixserverlib.Transaction,
) (*receipt.Receipt, error) {
var err error
jsonTransaction, err := json.Marshal(transaction)
if err != nil {
return nil, fmt.Errorf("failed to marshal: %w", err)
}
var nid int64
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction))
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
newReceipt := receipt.NewReceipt(nid)
return &newReceipt, nil
}
func (d *Database) AssociateTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
dbReceipt *receipt.Receipt,
) error {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var lastErr error
for destination := range destinations {
destination := destination
err := d.RelayQueue.InsertQueueEntry(
ctx,
txn,
transactionID,
destination.Domain(),
dbReceipt.GetNID(),
)
if err != nil {
lastErr = fmt.Errorf("d.insertQueueEntry: %w", err)
}
}
return lastErr
})
return err
}
func (d *Database) CleanTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
receipts []*receipt.Receipt,
) error {
nids := make([]int64, len(receipts))
for i, dbReceipt := range receipts {
nids[i] = dbReceipt.GetNID()
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids)
// TODO : If there are still queue entries for any of these nids for other destinations
// then we shouldn't delete the json entries.
// But this can't happen with the current api design.
// There will only ever be one server entry for each nid since each call to send_relay
// only accepts a single server name and inside there we create a new json entry.
// So for multiple destinations we would call send_relay multiple times and have multiple
// json entries of the same transaction.
//
// TLDR; this works as expected right now but can easily be optimised in the future.
deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids)
if deleteEntryErr != nil {
return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr)
}
if deleteJSONErr != nil {
return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr)
}
return nil
})
return err
}
func (d *Database) GetTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) {
entriesRequested := 1
nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested)
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err)
}
if len(nids) == 0 {
return nil, nil, nil
}
firstNID := nids[0]
txns := map[int64][]byte{}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids)
return err
})
if err != nil {
return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err)
}
transaction := &gomatrixserverlib.Transaction{}
if _, ok := txns[firstNID]; !ok {
return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID)
}
err = json.Unmarshal(txns[firstNID], transaction)
if err != nil {
return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err)
}
newReceipt := receipt.NewReceipt(firstNID)
return transaction, &newReceipt, nil
}
func (d *Database) GetTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
) (int64, error) {
count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
if err != nil {
return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err)
}
return count, nil
}

View File

@ -0,0 +1,137 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const relayQueueJSONSchema = `
-- The relayapi_queue_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS relayapi_queue_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
ON relayapi_queue_json (json_nid);
`
const insertQueueJSONSQL = "" +
"INSERT INTO relayapi_queue_json (json_body)" +
" VALUES ($1)"
const deleteQueueJSONSQL = "" +
"DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)"
const selectQueueJSONSQL = "" +
"SELECT json_nid, json_body FROM relayapi_queue_json" +
" WHERE json_nid IN ($1)"
type relayQueueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
s = &relayQueueJSONStatements{
db: db,
}
_, err = db.Exec(relayQueueJSONSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertJSONStmt, insertQueueJSONSQL},
}.Prepare(db)
}
func (s *relayQueueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json)
if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
}
lastid, err = res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err)
}
return
}
func (s *relayQueueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(nids))
for k, v := range nids {
iNIDs[k] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...)
return err
}
func (s *relayQueueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(jsonNIDs))
for k, v := range jsonNIDs {
iNIDs[k] = v
}
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
}
blobs[nid] = blob
}
return blobs, err
}

View File

@ -0,0 +1,168 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayQueueSchema = `
CREATE TABLE IF NOT EXISTS relayapi_queue (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the relayapi_queue_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
ON relayapi_queue (json_nid, server_name);
CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
ON relayapi_queue (json_nid);
CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
ON relayapi_queue (server_name);
`
const insertQueueEntrySQL = "" +
"INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueEntriesSQL = "" +
"DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueEntriesSQL = "" +
"SELECT json_nid FROM relayapi_queue" +
" WHERE server_name = $1" +
" ORDER BY json_nid" +
" LIMIT $2"
const selectQueueEntryCountSQL = "" +
"SELECT COUNT(*) FROM relayapi_queue" +
" WHERE server_name = $1"
type relayQueueStatements struct {
db *sql.DB
insertQueueEntryStmt *sql.Stmt
selectQueueEntriesStmt *sql.Stmt
selectQueueEntryCountStmt *sql.Stmt
// deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteRelayQueueTable(
db *sql.DB,
) (s *relayQueueStatements, err error) {
s = &relayQueueStatements{
db: db,
}
_, err = db.Exec(relayQueueSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertQueueEntryStmt, insertQueueEntrySQL},
{&s.selectQueueEntriesStmt, selectQueueEntriesSQL},
{&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL},
}.Prepare(db)
}
func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err)
}
params := make([]interface{}, len(jsonNIDs)+1)
params[0] = serverName
for k, v := range jsonNIDs {
params[k+1] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}

View File

@ -0,0 +1,64 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(
base *base.BaseDendrite,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
) (*Database, error) {
var d Database
var err error
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
queue, err := NewSQLiteRelayQueueTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
IsLocalServerName: isLocalServerName,
Cache: cache,
Writer: d.writer,
RelayQueue: queue,
RelayQueueJSON: queueJSON,
}
return &d, nil
}

View File

@ -0,0 +1,46 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new database
func NewDatabase(
base *base.BaseDendrite,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View File

@ -0,0 +1,66 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid.
// These are the transactions being stored for the given destination server.
// The nids correspond to entries in the RelayQueueJSON table.
type RelayQueue interface {
// Adds a new transaction_id: server_name mapping with associated json table nid to the table.
// Will ensure only one transaction id is present for each server_name: nid mapping.
// Adding duplicates will silently do nothing.
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
// Removes multiple entries from the table corresponding the the list of nids provided.
// If any of the provided nids don't match a row in the table, that deletion is considered
// successful.
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
// Get a list of nids associated with the provided server name.
// Returns up to `limit` nids. The entries are returned oldest first.
// Will return an empty list if no matches were found.
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
// Get the number of entries in the table associated with the provided server name.
// If there are no matching rows, a count of 0 is returned with err set to nil.
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
// RelayQueueJSON table contains a map of nid to the raw transaction json.
type RelayQueueJSON interface {
// Adds a new transaction to the table.
// Adding a duplicate transaction will result in a new row being added and a new unique nid.
// return: unique nid representing this entry.
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
// Removes multiple nids from the table.
// If any of the provided nids don't match a row in the table, that deletion is considered
// successful.
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
// Get the transaction json corresponding to the provided nids.
// Will return a partial result containing any matching nid from the table.
// Will return an empty map if no matches were found.
// It is the caller's responsibility to deal with the results appropriately.
// return: map indexed by nid of each matching transaction json.
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}

View File

@ -0,0 +1,173 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables_test
import (
"context"
"database/sql"
"encoding/json"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
)
func mustCreateTransaction() gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
return txn
}
type RelayQueueJSONDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.RelayQueueJSON
}
func mustCreateQueueJSONTable(
t *testing.T,
dbType test.DBType,
) (database RelayQueueJSONDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.RelayQueueJSON
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelayQueueJSONTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = RelayQueueJSONDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func TestShoudInsertTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
tx, err := json.Marshal(transaction)
if err != nil {
t.Fatalf("Invalid transaction: %s", err.Error())
}
_, err = db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
})
}
func TestShouldRetrieveInsertedTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
tx, err := json.Marshal(transaction)
if err != nil {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
var storedJSON map[int64][]byte
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, 1, len(storedJSON))
var storedTx gomatrixserverlib.Transaction
json.Unmarshal(storedJSON[1], &storedTx)
assert.Equal(t, transaction, storedTx)
})
}
func TestShouldDeleteTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueJSONTable(t, dbType)
defer close()
transaction := mustCreateTransaction()
tx, err := json.Marshal(transaction)
if err != nil {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
storedJSON := map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
storedJSON = map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, 0, len(storedJSON))
})
}

View File

@ -0,0 +1,229 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type RelayQueueDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.RelayQueue
}
func mustCreateQueueTable(
t *testing.T,
dbType test.DBType,
) (database RelayQueueDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.RelayQueue
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelayQueueTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = RelayQueueDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func TestShoudInsertQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
})
}
func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, nid, retrievedNids[0])
assert.Equal(t, 1, len(retrievedNids))
})
}
func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(2)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName = gomatrixserverlib.ServerName("domain")
oldestNID := int64(1)
err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, oldestNID, retrievedNids[0])
assert.Equal(t, 1, len(retrievedNids))
retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, oldestNID, retrievedNids[0])
assert.Equal(t, nid, retrievedNids[1])
assert.Equal(t, 2, len(retrievedNids))
})
}
func TestShouldDeleteQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, int64(0), count)
})
}
func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano()))
serverName2 := gomatrixserverlib.ServerName("domain2")
nid2 := int64(2)
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, int64(1), count)
count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, int64(1), count)
})
}

View File

@ -595,6 +595,12 @@ func (b *BaseDendrite) WaitForShutdown() {
logrus.Warnf("failed to flush all Sentry events!")
}
}
if b.Fulltext != nil {
err := b.Fulltext.Close()
if err != nil {
logrus.Warnf("failed to close full text search!")
}
}
logrus.Warnf("Dendrite is exiting now")
}

View File

@ -62,6 +62,7 @@ type Dendrite struct {
RoomServer RoomServer `yaml:"room_server"`
SyncAPI SyncAPI `yaml:"sync_api"`
UserAPI UserAPI `yaml:"user_api"`
RelayAPI RelayAPI `yaml:"relay_api"`
MSCs MSCs `yaml:"mscs"`
@ -349,6 +350,7 @@ func (c *Dendrite) Defaults(opts DefaultOpts) {
c.SyncAPI.Defaults(opts)
c.UserAPI.Defaults(opts)
c.AppServiceAPI.Defaults(opts)
c.RelayAPI.Defaults(opts)
c.MSCs.Defaults(opts)
c.Wiring()
}
@ -361,7 +363,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) {
&c.Global, &c.ClientAPI, &c.FederationAPI,
&c.KeyServer, &c.MediaAPI, &c.RoomServer,
&c.SyncAPI, &c.UserAPI,
&c.AppServiceAPI, &c.MSCs,
&c.AppServiceAPI, &c.RelayAPI, &c.MSCs,
} {
c.Verify(configErrs, isMonolith)
}
@ -377,6 +379,7 @@ func (c *Dendrite) Wiring() {
c.SyncAPI.Matrix = &c.Global
c.UserAPI.Matrix = &c.Global
c.AppServiceAPI.Matrix = &c.Global
c.RelayAPI.Matrix = &c.Global
c.MSCs.Matrix = &c.Global
c.ClientAPI.Derived = &c.Derived

View File

@ -18,6 +18,12 @@ type FederationAPI struct {
// The default value is 16 if not specified, which is circa 18 hours.
FederationMaxRetries uint32 `yaml:"send_max_retries"`
// P2P Feature: How many consecutive failures that we should tolerate when
// sending federation requests to a specific server until we should assume they
// are offline. If we assume they are offline then we will attempt to send
// messages to their relay server if we know of one that is appropriate.
P2PFederationRetriesUntilAssumedOffline uint32 `yaml:"p2p_retries_until_assumed_offline"`
// FederationDisableTLSValidation disables the validation of X.509 TLS certs
// on remote federation endpoints. This is not recommended in production!
DisableTLSValidation bool `yaml:"disable_tls_validation"`
@ -43,6 +49,7 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) {
c.Database.Defaults(10)
}
c.FederationMaxRetries = 16
c.P2PFederationRetriesUntilAssumedOffline = 2
c.DisableTLSValidation = false
c.DisableHTTPKeepalives = false
if opts.Generate {

View File

@ -0,0 +1,52 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
type RelayAPI struct {
Matrix *Global `yaml:"-"`
InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"`
ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"`
// The database stores information used by the relay queue to
// forward transactions to remote servers.
Database DatabaseOptions `yaml:"database,omitempty"`
}
func (c *RelayAPI) Defaults(opts DefaultOpts) {
if !opts.Monolithic {
c.InternalAPI.Listen = "http://localhost:7775"
c.InternalAPI.Connect = "http://localhost:7775"
c.ExternalAPI.Listen = "http://[::]:8075"
c.Database.Defaults(10)
}
if opts.Generate {
if !opts.Monolithic {
c.Database.ConnectionString = "file:relayapi.db"
}
}
}
func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if isMonolith { // polylith required configs below
return
}
if c.Matrix.DatabaseOptions.ConnectionString == "" {
checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString))
}
checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect))
}

View File

@ -20,11 +20,12 @@ import (
"testing"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
func TestLoadConfigRelative(t *testing.T) {
_, err := loadConfig("/my/config/dir", []byte(testConfig),
cfg, err := loadConfig("/my/config/dir", []byte(testConfig),
mockReadFile{
"/my/config/dir/matrix_key.pem": testKey,
"/my/config/dir/tls_cert.pem": testCert,
@ -34,6 +35,15 @@ func TestLoadConfigRelative(t *testing.T) {
if err != nil {
t.Error("failed to load config:", err)
}
configErrors := &ConfigErrors{}
cfg.Verify(configErrors, false)
if len(*configErrors) > 0 {
for _, err := range *configErrors {
logrus.Errorf("Configuration error: %s", err)
}
t.Error("configuration verification failed")
}
}
const testConfig = `
@ -68,6 +78,8 @@ global:
display_name: "Server alerts"
avatar: ""
room_name: "Server Alerts"
jetstream:
addresses: ["test"]
app_service_api:
internal_api:
listen: http://localhost:7777
@ -84,7 +96,7 @@ client_api:
connect: http://localhost:7771
external_api:
listen: http://[::]:8071
registration_disabled: false
registration_disabled: true
registration_shared_secret: ""
enable_registration_captcha: false
recaptcha_public_key: ""
@ -112,6 +124,8 @@ federation_api:
connect: http://localhost:7772
external_api:
listen: http://[::]:8072
database:
connection_string: file:federationapi.db
key_server:
internal_api:
listen: http://localhost:7779
@ -194,6 +208,17 @@ user_api:
max_open_conns: 100
max_idle_conns: 2
conn_max_lifetime: -1
relay_api:
internal_api:
listen: http://localhost:7775
connect: http://localhost:7775
external_api:
listen: http://[::]:8075
database:
connection_string: file:relayapi.db
mscs:
database:
connection_string: file:mscs.db
tracing:
enabled: false
jaeger:

View File

@ -23,6 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal/transactions"
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
@ -44,6 +46,7 @@ type Monolith struct {
RoomserverAPI roomserverAPI.RoomserverInternalAPI
UserAPI userapi.UserInternalAPI
KeyAPI keyAPI.KeyInternalAPI
RelayAPI relayAPI.RelayInternalAPI
// Optional
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
@ -71,4 +74,8 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
syncapi.AddPublicRoutes(
base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
)
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI)
}
}

View File

@ -101,7 +101,6 @@ func currentUser() string {
// Returns the connection string to use and a close function which must be called when the test finishes.
// Calling this function twice will return the same database, which will have data from previous tests
// unless close() is called.
// TODO: namespace for concurrent package tests
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
if dbType == DBTypeSQLite {
// this will be made in the t.TempDir, which is unique per test

View File

@ -0,0 +1,488 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"context"
"encoding/json"
"errors"
"sync"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
var nidMutex sync.Mutex
var nid = int64(0)
type InMemoryFederationDatabase struct {
dbMutex sync.Mutex
pendingPDUServers map[gomatrixserverlib.ServerName]struct{}
pendingEDUServers map[gomatrixserverlib.ServerName]struct{}
blacklistedServers map[gomatrixserverlib.ServerName]struct{}
assumedOffline map[gomatrixserverlib.ServerName]struct{}
pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent
pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU
associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}
associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}
relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName
}
func NewInMemoryFederationDatabase() *InMemoryFederationDatabase {
return &InMemoryFederationDatabase{
pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}),
pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent),
pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU),
associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}),
associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}),
relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName),
}
}
func (d *InMemoryFederationDatabase) StoreJSON(
ctx context.Context,
js string,
) (*receipt.Receipt, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal([]byte(js), &event); err == nil {
nidMutex.Lock()
defer nidMutex.Unlock()
nid++
newReceipt := receipt.NewReceipt(nid)
d.pendingPDUs[&newReceipt] = &event
return &newReceipt, nil
}
var edu gomatrixserverlib.EDU
if err := json.Unmarshal([]byte(js), &edu); err == nil {
nidMutex.Lock()
defer nidMutex.Unlock()
nid++
newReceipt := receipt.NewReceipt(nid)
d.pendingEDUs[&newReceipt] = &edu
return &newReceipt, nil
}
return nil, errors.New("Failed to determine type of json to store")
}
func (d *InMemoryFederationDatabase) GetPendingPDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
limit int,
) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
pduCount := 0
pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent)
if receipts, ok := d.associatedPDUs[serverName]; ok {
for dbReceipt := range receipts {
if event, ok := d.pendingPDUs[dbReceipt]; ok {
pdus[dbReceipt] = event
pduCount++
if pduCount == limit {
break
}
}
}
}
return pdus, nil
}
func (d *InMemoryFederationDatabase) GetPendingEDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
limit int,
) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
eduCount := 0
edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU)
if receipts, ok := d.associatedEDUs[serverName]; ok {
for dbReceipt := range receipts {
if event, ok := d.pendingEDUs[dbReceipt]; ok {
edus[dbReceipt] = event
eduCount++
if eduCount == limit {
break
}
}
}
}
return edus, nil
}
func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
dbReceipt *receipt.Receipt,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingPDUs[dbReceipt]; ok {
for destination := range destinations {
if _, ok := d.associatedPDUs[destination]; !ok {
d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{})
}
d.associatedPDUs[destination][dbReceipt] = struct{}{}
}
return nil
} else {
return errors.New("PDU doesn't exist")
}
}
func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
dbReceipt *receipt.Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingEDUs[dbReceipt]; ok {
for destination := range destinations {
if _, ok := d.associatedEDUs[destination]; !ok {
d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{})
}
d.associatedEDUs[destination][dbReceipt] = struct{}{}
}
return nil
} else {
return errors.New("EDU doesn't exist")
}
}
func (d *InMemoryFederationDatabase) CleanPDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*receipt.Receipt,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if pdus, ok := d.associatedPDUs[serverName]; ok {
for _, dbReceipt := range receipts {
delete(pdus, dbReceipt)
}
}
return nil
}
func (d *InMemoryFederationDatabase) CleanEDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*receipt.Receipt,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if edus, ok := d.associatedEDUs[serverName]; ok {
for _, dbReceipt := range receipts {
delete(edus, dbReceipt)
}
}
return nil
}
func (d *InMemoryFederationDatabase) GetPendingPDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var count int64
if pdus, ok := d.associatedPDUs[serverName]; ok {
count = int64(len(pdus))
}
return count, nil
}
func (d *InMemoryFederationDatabase) GetPendingEDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var count int64
if edus, ok := d.associatedEDUs[serverName]; ok {
count = int64(len(edus))
}
return count, nil
}
func (d *InMemoryFederationDatabase) GetPendingPDUServerNames(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingPDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *InMemoryFederationDatabase) GetPendingEDUServerNames(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingEDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *InMemoryFederationDatabase) AddServerToBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.blacklistedServers[serverName] = struct{}{}
return nil
}
func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
delete(d.blacklistedServers, serverName)
return nil
}
func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{})
return nil
}
func (d *InMemoryFederationDatabase) IsServerBlacklisted(
serverName gomatrixserverlib.ServerName,
) (bool, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
isBlacklisted := false
if _, ok := d.blacklistedServers[serverName]; ok {
isBlacklisted = true
}
return isBlacklisted, nil
}
func (d *InMemoryFederationDatabase) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.assumedOffline[serverName] = struct{}{}
return nil
}
func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
delete(d.assumedOffline, serverName)
return nil
}
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine(
ctx context.Context,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{})
return nil
}
func (d *InMemoryFederationDatabase) IsServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (bool, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
assumedOffline := false
if _, ok := d.assumedOffline[serverName]; ok {
assumedOffline = true
}
return assumedOffline, nil
}
func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
knownRelayServers := []gomatrixserverlib.ServerName{}
if relayServers, ok := d.relayServers[serverName]; ok {
knownRelayServers = relayServers
}
return knownRelayServers, nil
}
func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if knownRelayServers, ok := d.relayServers[serverName]; ok {
for _, relayServer := range relayServers {
alreadyKnown := false
for _, knownRelayServer := range knownRelayServers {
if relayServer == knownRelayServer {
alreadyKnown = true
}
}
if !alreadyKnown {
d.relayServers[serverName] = append(d.relayServers[serverName], relayServer)
}
}
} else {
d.relayServers[serverName] = relayServers
}
return nil
}
func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) FetcherName() string {
return ""
}
func (d *InMemoryFederationDatabase) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error {
return nil
}
func (d *InMemoryFederationDatabase) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error {
return nil
}
func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
return nil
}
func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error {
return nil
}
func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return nil
}
func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return nil
}
func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return nil
}
func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return nil
}
func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
return nil
}
func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) {
return nil, nil
}
func (d *InMemoryFederationDatabase) DeleteExpiredEDUs(ctx context.Context) error {
return nil
}
func (d *InMemoryFederationDatabase) PurgeRoom(ctx context.Context, roomID string) error {
return nil
}

140
test/memory_relay_db.go Normal file
View File

@ -0,0 +1,140 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/matrix-org/gomatrixserverlib"
)
type InMemoryRelayDatabase struct {
nid int64
nidMutex sync.Mutex
transactions map[int64]json.RawMessage
associations map[gomatrixserverlib.ServerName][]int64
}
func NewInMemoryRelayDatabase() *InMemoryRelayDatabase {
return &InMemoryRelayDatabase{
nid: 1,
nidMutex: sync.Mutex{},
transactions: make(map[int64]json.RawMessage),
associations: make(map[gomatrixserverlib.ServerName][]int64),
}
}
func (d *InMemoryRelayDatabase) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
if _, ok := d.associations[serverName]; !ok {
d.associations[serverName] = []int64{}
}
d.associations[serverName] = append(d.associations[serverName], nid)
return nil
}
func (d *InMemoryRelayDatabase) DeleteQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
for _, nid := range jsonNIDs {
for index, associatedNID := range d.associations[serverName] {
if associatedNID == nid {
d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...)
}
}
}
return nil
}
func (d *InMemoryRelayDatabase) SelectQueueEntries(
ctx context.Context,
txn *sql.Tx, serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
results := []int64{}
resultCount := limit
if limit > len(d.associations[serverName]) {
resultCount = len(d.associations[serverName])
}
if resultCount > 0 {
for i := 0; i < resultCount; i++ {
results = append(results, d.associations[serverName][i])
}
}
return results, nil
}
func (d *InMemoryRelayDatabase) SelectQueueEntryCount(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return int64(len(d.associations[serverName])), nil
}
func (d *InMemoryRelayDatabase) InsertQueueJSON(
ctx context.Context,
txn *sql.Tx,
json string,
) (int64, error) {
d.nidMutex.Lock()
defer d.nidMutex.Unlock()
nid := d.nid
d.transactions[nid] = []byte(json)
d.nid++
return nid, nil
}
func (d *InMemoryRelayDatabase) DeleteQueueJSON(
ctx context.Context,
txn *sql.Tx,
nids []int64,
) error {
for _, nid := range nids {
delete(d.transactions, nid)
}
return nil
}
func (d *InMemoryRelayDatabase) SelectQueueJSON(
ctx context.Context,
txn *sql.Tx,
jsonNIDs []int64,
) (map[int64][]byte, error) {
result := make(map[int64][]byte)
for _, nid := range jsonNIDs {
if transaction, ok := d.transactions[nid]; ok {
result[nid] = transaction
}
}
return result, nil
}

View File

@ -67,9 +67,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
case test.DBTypeSQLite:
cfg.Defaults(config.DefaultOpts{
Generate: true,
Monolithic: false, // because we need a database per component
Monolithic: true,
})
cfg.Global.ServerName = "test"
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :(
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
@ -83,6 +84,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db"))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db"))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db"))
cfg.RelayAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "relayapi.db"))
base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics)
return base, func() {