diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index b579ae1fa..ce018904f 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" @@ -92,12 +93,13 @@ func Setup( v2keysmux.Handle("/query", notaryKeys).Methods(http.MethodPost) v2keysmux.Handle("/query/{serverName}/{keyID}", notaryKeys).Methods(http.MethodGet) + mu := internal.NewMutexByRoom() v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( "federation_send", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, eduAPI, keyAPI, keys, federation, + cfg, rsAPI, eduAPI, keyAPI, keys, federation, mu, ) }, )).Methods(http.MethodPut, http.MethodOptions) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 231a16863..b48d6c0b8 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -26,6 +26,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -98,6 +99,7 @@ func Send( keyAPI keyapi.KeyInternalAPI, keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, + mu *internal.MutexByRoom, ) util.JSONResponse { t := txnReq{ rsAPI: rsAPI, @@ -107,6 +109,7 @@ func Send( haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), newEvents: make(map[string]bool), keyAPI: keyAPI, + roomsMu: mu, } var txnEvents struct { @@ -163,6 +166,7 @@ type txnReq struct { federation txnFederationClient servers []gomatrixserverlib.ServerName serversMutex sync.RWMutex + roomsMu *internal.MutexByRoom // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. haveEvents map[string]*gomatrixserverlib.HeaderedEvent @@ -494,6 +498,8 @@ func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserver } func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { + t.roomsMu.Lock(e.RoomID()) + defer t.roomsMu.Unlock(e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) t.work = "" // reset from previous event diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 8bdf54c4a..b14cbd35a 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -9,6 +9,7 @@ import ( "time" eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -370,6 +371,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat federation: fedClient, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), newEvents: make(map[string]bool), + roomsMu: internal.NewMutexByRoom(), } t.PDUs = pdus t.Origin = testOrigin diff --git a/internal/mutex.go b/internal/mutex.go new file mode 100644 index 000000000..3d36cdac9 --- /dev/null +++ b/internal/mutex.go @@ -0,0 +1,38 @@ +package internal + +import "sync" + +type MutexByRoom struct { + mu *sync.Mutex // protects the map + roomToMu map[string]*sync.Mutex +} + +func NewMutexByRoom() *MutexByRoom { + return &MutexByRoom{ + mu: &sync.Mutex{}, + roomToMu: make(map[string]*sync.Mutex), + } +} + +func (m *MutexByRoom) Lock(roomID string) { + m.mu.Lock() + roomMu := m.roomToMu[roomID] + if roomMu == nil { + roomMu = &sync.Mutex{} + } + m.roomToMu[roomID] = roomMu + m.mu.Unlock() + // don't lock inside m.mu else we can deadlock + roomMu.Lock() +} + +func (m *MutexByRoom) Unlock(roomID string) { + m.mu.Lock() + roomMu := m.roomToMu[roomID] + if roomMu == nil { + panic("MutexByRoom: Unlock before Lock") + } + m.mu.Unlock() + + roomMu.Unlock() +}