From 7a2e325d1014d76188b47a011730a42443f3c174 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 13 Jun 2023 16:28:41 +0200 Subject: [PATCH] Add `AssignRoomNID` to pre-assign roomNIDs (#3111) --- roomserver/storage/interface.go | 2 ++ roomserver/storage/shared/storage.go | 20 ++++++++++++++++++++ roomserver/storage/shared/storage_test.go | 22 ++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index ef4463781..7787d9f85 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -31,6 +31,7 @@ type Database interface { UserRoomKeys // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) @@ -212,6 +213,7 @@ type UserRoomKeys interface { type RoomDatabase interface { EventDatabase UserRoomKeys + AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d7ca3cefd..bda51da81 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -662,6 +662,26 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } +func (d *Database) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) { + // This should already be checked, let's check it anyway. + _, err = gomatrixserverlib.GetRoomVersion(roomVersion) + if err != nil { + return 0, err + } + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err = d.assignRoomNID(ctx, txn, roomID.String(), roomVersion) + if err != nil { + return err + } + return nil + }) + if err != nil { + return 0, err + } + // Not setting caches, as assignRoomNID already does this + return roomNID, err +} + // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 4fa451bcc..581d83ee4 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" @@ -199,3 +200,24 @@ func TestUserRoomKeys(t *testing.T) { assert.Error(t, err) }) } + +func TestAssignRoomNID(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + roomID, err := spec.NewRoomID(room.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + nid, err := db.AssignRoomNID(ctx, *roomID, room.Version) + assert.NoError(t, err) + assert.Greater(t, nid, types.EventNID(0)) + + _, err = db.AssignRoomNID(ctx, spec.RoomID{}, "notaroomversion") + assert.Error(t, err) + }) +}