diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ff124514e..1877de37a 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), }, @@ -236,7 +236,11 @@ func (s *appserviceState) backoffAndPause(err error) error { // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { user := "" - userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err == nil { user = userID.String() } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index f01e24eca..d9129d1bd 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -233,11 +233,18 @@ func RemoveLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) + validRoomID, err := spec.NewRoomID(roomIDRes.RoomID) if err != nil { return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), } } @@ -321,7 +328,15 @@ func SetVisibility( JSON: spec.BadJSON("userID for this device is invalid"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 03e85edbf..bafc37b67 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -64,7 +64,14 @@ func SendBan( JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -155,7 +162,14 @@ func SendKick( JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -428,7 +442,11 @@ func buildMembershipEvent( if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) if err != nil { return nil, err } @@ -437,7 +455,7 @@ func buildMembershipEvent( if err != nil { return nil, err } - targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) + targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID) if err != nil { return nil, err } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index e734e2e4f..8a44834e1 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -368,7 +368,11 @@ func buildMembershipEvents( return nil, err } for _, roomID := range roomIDs { - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index da48e84de..42f029395 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -54,7 +54,14 @@ func SendRedaction( JSON: spec.Forbidden("userID doesn't have power level to redact"), } } - senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID) if queryErr != nil { return util.JSONResponse{ Code: http.StatusForbidden, @@ -103,8 +110,8 @@ func SendRedaction( JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), } } - pl, err := plEvent.PowerLevels() - if err != nil { + pl, plErr := plEvent.PowerLevels() + if plErr != nil { return util.JSONResponse{ Code: 403, JSON: spec.Forbidden( @@ -134,7 +141,7 @@ func SendRedaction( Type: spec.MRoomRedaction, Redacts: eventID, } - err := proto.SetContent(r) + err = proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") return util.JSONResponse{ diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 4d0a9f24a..d51a570de 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -273,7 +273,14 @@ func generateSendEvent( JSON: spec.BadJSON("Bad userID"), } } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("RoomID is invalid"), + } + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusNotFound, @@ -344,8 +351,8 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID) }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index e3a209b6e..f53cb3013 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -150,7 +150,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, ev), ) @@ -173,14 +173,19 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } for _, ev := range stateAfterRes.StateEvents { sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + evRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid") + continue + } + userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID()) if err == nil && userID != nil { sender = *userID } sk := ev.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -367,7 +372,7 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, event), } diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index e7ffbac2b..d15cc6d46 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,7 +359,11 @@ func emit3PIDInviteEvent( if err != nil { return err } - sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID) if err != nil { return err } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 15c87f1a8..3ffcac9e6 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -11,11 +11,13 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -66,10 +68,14 @@ func main() { panic(err) } + natsInstance := &jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, + natsInstance, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), false) + roomInfo := &types.RoomInfo{ RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), } - stateres := state.NewStateResolution(roomserverDB, roomInfo) + stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI) if *difference { if len(snapshotNIDs) != 2 { @@ -183,8 +189,8 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 173908437..5d167c0ee 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,11 +36,11 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } -func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } -func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { +func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { return spec.SenderID(userID.String()), nil } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 485b79a03..7f61dba41 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -154,14 +154,9 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err != nil { return err } - senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user) - if err != nil { - return err - } joinInput := gomatrixserverlib.PerformJoinInput{ UserID: user, - SenderID: senderID, RoomID: room, ServerName: serverName, Content: content, @@ -169,12 +164,20 @@ func (r *FederationInternalAPI) performJoinUsingServer( PrivateKey: r.cfg.Matrix.PrivateKey, KeyID: r.cfg.Matrix.KeyID, KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, + SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (spec.SenderID, error) { + key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if keyErr != nil { + return "", keyErr + } + + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -368,7 +371,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( @@ -459,7 +462,11 @@ func (r *FederationInternalAPI) PerformLeave( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" - senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) + roomID, err := spec.NewRoomID(request.RoomID) + if err != nil { + return err + } + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { return err } @@ -527,7 +534,11 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { return nil, err } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 5b15f810d..e45209a2f 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,7 +95,7 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -188,7 +188,7 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index d14801921..7aa50f65a 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -98,7 +98,7 @@ func MakeJoin( Roomserver: rsAPI, } - senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ @@ -118,7 +118,7 @@ func MakeJoin( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, RoomQuerier: &roomQuerier, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, BuildEventTemplate: createJoinTemplate, @@ -215,7 +215,7 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 716276bec..5c8dd00f3 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -87,7 +87,7 @@ func MakeLeave( return event, stateEvents, nil } - senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") return util.JSONResponse{ @@ -105,7 +105,7 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -236,7 +236,14 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Room ID is invalid."), + } + } + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID()) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 360802de5..42ba8bfe5 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -140,7 +140,14 @@ func ExchangeThirdPartyInvite( } } - userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID)) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("Invalid room ID"), + } + } + userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID)) if err != nil || userID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -150,7 +157,7 @@ func ExchangeThirdPartyInvite( senderDomain := userID.Domain() // Check that the state key is correct. - targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey)) + targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey)) if err != nil || targetUserID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/go.mod b/go.mod index 2fbae3148..930db3958 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -42,11 +42,11 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.9.0 + golang.org/x/crypto v0.10.0 golang.org/x/image v0.5.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/sync v0.1.0 - golang.org/x/term v0.8.0 + golang.org/x/term v0.9.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -127,8 +127,8 @@ require ( golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.6.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index ef8c298ab..cf6993938 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 h1:k75Fy0iQVbDjvddip/x898+BdyopBNAfL1BMNx0awA0= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -511,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -669,12 +669,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -683,8 +683,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index ac7608950..28dea97c4 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati case SenderKind: userID := "" - sender, err := userIDForSender(event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false, err + } + sender, err := userIDForSender(*validRoomID, event.SenderID()) if err == nil { userID = sender.String() } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 859d1f8a6..a4ccc3d0f 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) { {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, {"emptyContent", ContentKind, emptyRule, `{}`, false}, {"emptyRoom", RoomKind, emptyRule, `{}`, true}, - {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, {"disabled", OverrideKind, Rule{}, `{}`, false}, @@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) { {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index b2929bb5d..5bf7d819c 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 1d32c8060..ffc1cd89a 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,7 +70,7 @@ type FakeRsAPI struct { bannedFromRoom bool } -func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -642,7 +642,7 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index fec28841e..e2dd5dd73 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -51,6 +51,7 @@ type RoomserverInternalAPI interface { UserRoomserverAPI FederationRoomserverAPI QuerySenderIDAPI + UserRoomPrivateKeyCreator // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -67,7 +68,9 @@ type RoomserverInternalAPI interface { req *QueryAuthChainRequest, res *QueryAuthChainResponse, ) error +} +type UserRoomPrivateKeyCreator interface { // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) } @@ -81,8 +84,8 @@ type InputRoomEventsAPI interface { } type QuerySenderIDAPI interface { - QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) - QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) + QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) + QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. @@ -228,6 +231,7 @@ type FederationRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QuerySenderIDAPI + UserRoomPrivateKeyCreator // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index ba10a4332..d6c10cf92 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -15,7 +15,7 @@ package auth import ( "context" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -25,7 +25,7 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( - ctx context.Context, db storage.RoomDatabase, + ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.PDU, @@ -41,7 +41,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -50,7 +50,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -74,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver return visibility } -func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { for _, ev := range authEvents { if ev.Type() != spec.MRoomMember { continue @@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabas continue } - userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + continue + } + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey)) if err != nil { continue } diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go index 192d9e5da..058361e6e 100644 --- a/roomserver/auth/auth_test.go +++ b/roomserver/auth/auth_test.go @@ -4,17 +4,17 @@ import ( "context" "testing" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) -type FakeStorageDB struct { - storage.RoomDatabase +type FakeQuerier struct { + api.QuerySenderIDAPI } -func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -87,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) { authEvents = append(authEvents, ev.PDU) } - if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { + if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) } }) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index c950024ad..e6fb73383 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID( return nil } +// nolint:gocyclo // RemoveRoomAlias implements alias.RoomserverInternalAPI func (r *RoomserverInternalAPI) RemoveRoomAlias( ctx context.Context, @@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } - sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + + sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID) if err != nil || sender == nil { return fmt.Errorf("r.QueryUserIDForSender: %w", err) } @@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( if request.SenderID != ev.SenderID() { senderID = ev.SenderID() } - sender, err := r.QueryUserIDForSender(ctx, roomID, senderID) + sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID) if err != nil || sender == nil { return err } @@ -206,7 +212,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } stateRes := &api.QueryLatestEventsAndStateResponse{} - if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 4bcd3f3ed..7943ae5c0 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -177,6 +177,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio IsLocalServerName: r.Cfg.Global.IsLocalServerName, DB: r.DB, FSAPI: r.fsAPI, + Querier: r.Queryer, KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 7782d07d2..89fae244f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" @@ -36,6 +37,7 @@ func CheckForSoftFail( roomInfo *types.RoomInfo, event *types.HeaderedEvent, stateEventIDs []string, + querier api.QuerySenderIDAPI, ) (bool, error) { rewritesState := len(stateEventIDs) > 1 @@ -49,7 +51,7 @@ func CheckForSoftFail( } else { // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID()) if err != nil { return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) @@ -76,8 +78,8 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return db.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return querier.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { // return true, nil return true, err diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 263cb9f85..febabf411 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -68,7 +68,7 @@ func UpdateToInviteMembership( // memberships. If the servername is not supplied then the local server will be // checked instead using a faster code path. // TODO: This should probably be replaced by an API call. -func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) { +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, err @@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( @@ -211,8 +211,8 @@ func GetMembershipsAtState( return events, nil } -func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } @@ -264,7 +264,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent( case tables.OptimisationNotSupportedError: // The database engine didn't support this optimisation, so fall back to using // the old and slow method - stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName) + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier) if err != nil { return false, err } @@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]gomatrixserverlib.PDU, error) { - roomState := state.NewStateResolution(db, info) + roomState := state.NewStateResolution(db, info, querier) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState( // If the event state key doesn't match the given servername // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } for nid, key := range stateKeys { if key != "" { - userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key)) + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key)) if err == nil && userID != nil { if userID.Domain() != serverName { delete(stateKeys, nid) @@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName spec.ServerName, + serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]types.EventNID, map[string]struct{}, error) { var resultNIDs []types.EventNID var err error @@ -392,7 +396,7 @@ BFSLoop: // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] - isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } @@ -415,7 +419,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -444,7 +448,7 @@ BFSLoop: } func QueryLatestEventsAndState( - ctx context.Context, db storage.Database, + ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { @@ -457,7 +461,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 7bb401632..aa05d9594 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) } @@ -282,8 +286,8 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { isRejected = true rejectionErr = err @@ -321,7 +325,7 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew && !isCreateEvent { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -401,7 +405,7 @@ func (r *Inputer) processRoomEvent( redactedEvent gomatrixserverlib.PDU ) if !isRejected && !isCreateEvent { - resolver := state.NewStateResolution(r.DB, roomInfo) + resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer) redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) if err != nil { return err @@ -587,8 +591,8 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return @@ -700,8 +704,8 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue nextAuthEvent } @@ -718,8 +722,8 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) @@ -783,7 +787,7 @@ func (r *Inputer) calculateAndSetState( return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) } defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - roomState := state.NewStateResolution(updater, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo, r.Queryer) if input.HasState { // We've been told what the state at the event is so we don't need to calculate it. @@ -836,13 +840,18 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + prevEvents := latestRes.LatestEvents for _, memberEvent := range memberEvents { if memberEvent.StateKey() == nil { continue } - memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey())) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 5f2cd9562..4ee6d2110 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) }); err == nil { t.Fatalf("event should not be allowed, but it was") diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 7a7a021a3..940783e03 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error { defer trace.EndRegion() var err error - roomState := state.NewStateResolution(u.updater, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 09c65dfe9..c46f8dba1 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -139,7 +139,11 @@ func (r *Inputer) updateMembership( func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return isTargetLocalUser + } + userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey)) if err != nil || userID == nil { return isTargetLocalUser } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index f0f974d26..7ee84e4c0 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -383,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even defer trace.EndRegion() var res parsedRespState - roomState := state.NewStateResolution(t.db, t.roomInfo) + roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer) stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) if err != nil { t.log.WithError(err).Warnf("failed to get state after %s locally", eventID) @@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: @@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } @@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { return nil, err @@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index ec13bff87..12b557f51 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom( if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { return nil, err } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } prevEvents := latestRes.LatestEvents var senderDomain spec.ServerName @@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom( PrevEvents: prevEvents, } - userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID)) + userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID)) if err != nil || userID == nil { continue } @@ -264,16 +268,16 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } @@ -293,7 +297,11 @@ func (r *Admin) PerformAdminDownloadState( stateIDs = append(stateIDs, stateEvent.EventID()) } - senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return err + } + senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 8e87359a3..533ad25bf 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -42,6 +42,7 @@ type Backfiller struct { DB storage.Database FSAPI federationAPI.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier + Querier api.QuerySenderIDAPI // The servers which should be preferred above other servers when backfilling PreferServers []spec.ServerName @@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill( } // Scan the event tree for events to send back. - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier) if err != nil { return err } @@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass @@ -121,8 +122,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) }, ) // Only return an error if we really couldn't get any events. @@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events) for _, ev := range backfilledEventMap { // now add state for these events @@ -212,8 +213,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.Querier.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") @@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom } } util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) - persistEvents(ctx, r.DB, newEvents) + persistEvents(ctx, r.DB, r.Querier, newEvents) } // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { db storage.Database fsAPI federationAPI.RoomserverFederationAPI + querier api.QuerySenderIDAPI virtualHost spec.ServerName isLocalServerName func(spec.ServerName) bool preferServer map[spec.ServerName]bool @@ -268,6 +270,7 @@ type backfillRequester struct { func newBackfillRequester( db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, + querier api.QuerySenderIDAPI, virtualHost spec.ServerName, isLocalServerName func(spec.ServerName) bool, bwExtrems map[string][]string, preferServers []spec.ServerName, @@ -279,6 +282,7 @@ func newBackfillRequester( return &backfillRequester{ db: db, fsAPI: fsAPI, + querier: querier, virtualHost: virtualHost, isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), @@ -460,14 +464,14 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -488,7 +492,11 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + continue + } + if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil { serverSet[sender.Domain()] = true } } @@ -554,7 +562,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID @@ -582,7 +590,7 @@ func joinEventsFromHistoryVisibility( } // Can we see events in the room? - canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events) + canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events) visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) @@ -597,7 +605,7 @@ func joinEventsFromHistoryVisibility( return evs, visibility, err } -func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) { +func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) { var roomNID types.RoomNID var eventNID types.EventNID backfilledEventMap := make(map[string]types.Event) @@ -639,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse continue } - resolver := state.NewStateResolution(db, roomInfo) + resolver := state.NewStateResolution(db, roomInfo, querier) _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver) if err != nil { diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 121b257ed..fd8055e09 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -63,13 +63,20 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } } - senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, + var senderID spec.SenderID + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + // create user room key if needed + key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if keyErr != nil { + util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } } + senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + } else { + senderID = spec.SenderID(userID.String()) } createContent["creator"] = senderID createContent["room_version"] = createRequest.RoomVersion @@ -323,8 +330,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return c.DB.GetUserIDForSender(ctx, roomID, senderID) + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ @@ -364,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - // create user room key if needed - if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - } - // send the remaining events if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") @@ -455,7 +450,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo JSON: spec.InternalServerError{}, } } - inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID) + inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID) if queryErr != nil { util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 3ac0f6f4d..7fbec3710 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek( response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]} // XXX: do we actually need to do a state resolution here? - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer) var stateEntries []types.StateEntry stateEntries, err = roomState.LoadStateAtSnapshot( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index cc2c5c191..babd5f812 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -34,6 +34,7 @@ import ( type QueryState struct { storage.Database + querier api.QuerySenderIDAPI } func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { @@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant return nil, fmt.Errorf("failed to load RoomInfo: %w", err) } if info != nil { - roomState := state.NewStateResolution(q.Database, info) + roomState := state.NewStateResolution(q.Database, info, q.querier) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID(), stateWanted, ) @@ -98,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership( var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) + if err != nil { + return nil, err + } + userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } @@ -126,7 +131,12 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + + sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { return spec.InvalidParam("The sender user ID is invalid") } @@ -137,18 +147,13 @@ func (r *Inviter) PerformInvite( if event.StateKey() == nil || *event.StateKey() == "" { return fmt.Errorf("invite must be a state event") } - invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) if err != nil || invitedUser == nil { return spec.InvalidParam("Could not find the matching senderID for this user") } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - - invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser) + invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser) if err != nil { return fmt.Errorf("failed looking up senderID for invited user") } @@ -161,9 +166,9 @@ func (r *Inviter) PerformInvite( IsTargetLocal: isTargetLocal, StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, - StateQuerier: &QueryState{r.DB}, - UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + StateQuerier: &QueryState{r.DB, r.RSAPI}, + UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID) }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 74ed87c74..5867ee6e0 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -25,6 +25,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -174,44 +175,6 @@ func (r *Joiner) performJoinRoomByID( req.ServerNames = append(req.ServerNames, roomID.Domain()) } - // Prepare the template for the join event. - userID, err := spec.NewUserID(req.UserID, true) - if err != nil { - return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} - } - senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID) - if err != nil { - return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} - } - senderIDString := string(senderID) - userDomain := userID.Domain() - proto := gomatrixserverlib.ProtoEvent{ - Type: spec.MRoomMember, - SenderID: senderIDString, - StateKey: &senderIDString, - RoomID: req.RoomIDOrAlias, - Redacts: "", - } - if err = proto.SetUnsigned(struct{}{}); err != nil { - return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) - } - - // It is possible for the request to include some "content" for the - // event. We'll always overwrite the "membership" key, but the rest, - // like "display_name" or "avatar_url", will be kept if supplied. - if req.Content == nil { - req.Content = map[string]interface{}{} - } - req.Content["membership"] = spec.Join - if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { - return "", "", aerr - } else if authorisedVia != "" { - req.Content["join_authorised_via_users_server"] = authorisedVia - } - if err = proto.SetContent(req.Content); err != nil { - return "", "", fmt.Errorf("eb.SetContent: %w", err) - } - // Force a federated join if we aren't in the room and we've been // given some server names to try joining by. inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{ @@ -224,29 +187,63 @@ func (r *Joiner) performJoinRoomByID( serverInRoom := inRoomRes.IsInRoom forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom + userID, err := spec.NewUserID(req.UserID, true) + if err != nil { + return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} + } + + // Look up the room NID for the supplied room ID. + var senderID spec.SenderID + checkInvitePending := false + info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias) + if err == nil && info != nil { + switch info.RoomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) + if err == nil { + checkInvitePending = true + } else { + // create user room key if needed + key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if keyErr != nil { + util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr) + } + senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + } + default: + checkInvitePending = true + senderID = spec.SenderID(userID.String()) + } + } + + userDomain := userID.Domain() + // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) - if err == nil && !serverInRoom && isInvitePending { - inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) - if queryErr != nil { - return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) - } + if checkInvitePending { + isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) + if inviteErr == nil && !serverInRoom && isInvitePending { + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) + } - // If we were invited by someone from another server then we can - // assume they are in the room so we can join via them. - if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { - req.ServerNames = append(req.ServerNames, inviter.Domain()) - forceFederatedJoin = true - memberEvent := gjson.Parse(string(inviteEvent.JSON())) - // only set unsigned if we've got a content.membership, which we _should_ - if memberEvent.Get("content.membership").Exists() { - req.Unsigned = map[string]interface{}{ - "prev_sender": memberEvent.Get("sender").Str, - "prev_content": map[string]interface{}{ - "is_direct": memberEvent.Get("content.is_direct").Bool(), - "membership": memberEvent.Get("content.membership").Str, - }, + // If we were invited by someone from another server then we can + // assume they are in the room so we can join via them. + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) + forceFederatedJoin = true + memberEvent := gjson.Parse(string(inviteEvent.JSON())) + // only set unsigned if we've got a content.membership, which we _should_ + if memberEvent.Get("content.membership").Exists() { + req.Unsigned = map[string]interface{}{ + "prev_sender": memberEvent.Get("sender").Str, + "prev_content": map[string]interface{}{ + "is_direct": memberEvent.Get("content.is_direct").Bool(), + "membership": memberEvent.Get("content.membership").Str, + }, + } } } } @@ -274,6 +271,7 @@ func (r *Joiner) performJoinRoomByID( // If we should do a forced federated join then do that. var joinedVia spec.ServerName if forceFederatedJoin { + // TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) return req.RoomIDOrAlias, joinedVia, err } @@ -289,19 +287,40 @@ func (r *Joiner) performJoinRoomByID( if err != nil { return "", "", fmt.Errorf("error joining local room: %q", err) } + + senderIDString := string(senderID) + + // Prepare the template for the join event. + proto := gomatrixserverlib.ProtoEvent{ + Type: spec.MRoomMember, + SenderID: senderIDString, + StateKey: &senderIDString, + RoomID: req.RoomIDOrAlias, + Redacts: "", + } + if err = proto.SetUnsigned(struct{}{}); err != nil { + return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) + } + + // It is possible for the request to include some "content" for the + // event. We'll always overwrite the "membership" key, but the rest, + // like "display_name" or "avatar_url", will be kept if supplied. + if req.Content == nil { + req.Content = map[string]interface{}{} + } + req.Content["membership"] = spec.Join + if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil { + return "", "", aerr + } else if authorisedVia != "" { + req.Content["join_authorised_via_users_server"] = authorisedVia + } + if err = proto.SetContent(req.Content); err != nil { + return "", "", fmt.Errorf("eb.SetContent: %w", err) + } event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) switch err.(type) { case nil: - // create user room key if needed - if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) - if err != nil { - logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") - return "", "", fmt.Errorf("failed to get user room private key: %w", err) - } - } - // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 1b23cc1ff..e1ddb9b50 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID( req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam ) ([]api.OutputEvent, error) { - leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) + roomID, err := spec.NewRoomID(req.RoomID) + if err != nil { + return nil, err + } + leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver) if err != nil { return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) } @@ -87,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID( // that. isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) if err == nil && isInvitePending { - sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) + sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser) if serr != nil || sender == nil { return nil, fmt.Errorf("sender %q has no matching userID", senderUser) } @@ -133,7 +137,7 @@ func (r *Leaver) performLeaveRoomByID( }, } latestRes := api.QueryLatestEventsAndStateResponse{} - if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil { return nil, err } if !latestRes.RoomExists { diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 1aaa42c94..32f547dc1 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade( return "", err } - senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) + fullRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return "", err + } + senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") return "", err @@ -488,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send } - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) @@ -569,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index caea6b526..19fd456b5 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,6 +16,7 @@ package query import ( "context" + "crypto/ed25519" "database/sql" "errors" "fmt" @@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) + return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response) } // QueryStateAfterEvents implements api.RoomserverInternalAPI @@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser( request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, ) error { - senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID) - if err != nil { - return err - } - roomID, err := spec.NewRoomID(request.RoomID) if err != nil { return err } + senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) + if err != nil { + return err + } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) } @@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent( } response.Membership = make(map[string]*types.HeaderedEvent) - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) @@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) } else { - stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) @@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( } return helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, roomID, eventID, serverName, isInRoom, + ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r, ) } @@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r) if err != nil { return err } @@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain( // first bool: is rejected, second bool: state missing func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) { - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(r.DB, roomInfo, r) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - return r.DB.GetSenderIDForUser(ctx, roomID, userID) +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + version, err := r.DB.GetRoomVersion(ctx, roomID.String()) + if err != nil { + return "", err + } + + switch version { + case gomatrixserverlib.RoomVersionPseudoIDs: + key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) + if err != nil { + return "", err + } + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + default: + return spec.SenderID(userID.String()), nil + } } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + userID, err := spec.NewUserID(string(senderID), true) + if err == nil { + return userID, nil + } + + bytes := spec.Base64Bytes{} + err = bytes.Decode(string(senderID)) + if err != nil { + return nil, err + } + queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}} + result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap) + if err != nil { + return nil, err + } + + if userKeys, ok := result[roomID]; ok { + if userID, ok := userKeys[string(senderID)]; ok { + return spec.NewUserID(userID, true) + } + } + + return nil, nil } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 90c94bbce..077957fa1 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -516,6 +516,9 @@ func TestRedaction(t *testing.T) { t.Fatal(err) } + natsInstance := &jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { authEvents := []types.EventNID{} @@ -551,7 +554,7 @@ func TestRedaction(t *testing.T) { } // Calculate the snapshotNID etc. - plResolver := state.NewStateResolution(db, roomInfo) + plResolver := state.NewStateResolution(db, roomInfo, rsAPI) stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false) assert.NoError(t, err) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index b9c5bbc4a..1e776ff6c 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -29,6 +29,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -44,20 +45,21 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } type StateResolution struct { db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]gomatrixserverlib.PDU + Querier api.QuerySenderIDAPI } -func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, events: make(map[types.EventNID]gomatrixserverlib.PDU), + Querier: querier, } } @@ -947,8 +949,8 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomID, senderID) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) }) // Map from the full events back to numeric state entries. @@ -1061,8 +1063,8 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, - func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomID, senderID) + func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7787d9f85..7156c11cc 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -169,10 +169,6 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) - // GetKnownUsers tries to obtain the current mxid for a given user. - GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) - // GetKnownUsers tries to obtain the current senderID for a given user. - GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -190,6 +186,7 @@ type Database interface { ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ) (map[string]*types.HeaderedEvent, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) + GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) MaybeRedactEvent( @@ -205,8 +202,12 @@ type UserRoomKeys interface { InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) // SelectUserRoomPrivateKey selects the private key for the given user and room combination SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. // If a senderKey can't be found, it is omitted in the result. + // TODO: Why is the result map indexed by string not public key? + // TODO: Shouldn't the input & result map be changed to be indexed by string instead of the RoomID struct? SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) } @@ -233,7 +234,6 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) - GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 22f978bf0..dbb4af34a 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt } @@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL}, }.Prepare(db) } @@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 6fb57332a..70672a33e 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -251,7 +250,3 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } - -func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return u.d.GetUserIDForSender(ctx, roomID, senderID) -} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index bda51da81..61a3520a4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -721,6 +721,22 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver }, err } +func (d *Database) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) { + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID) + if versionOK { + return cachedRoomVersion, nil + } + + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return "", err + } + if roomInfo == nil { + return "", nil + } + return roomInfo.RoomVersion, nil +} + func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil { @@ -1550,16 +1566,6 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } -func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - // TODO: Use real logic once DB for pseudoIDs is in place - return spec.NewUserID(string(senderID), true) -} - -func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - // TODO: Use real logic once DB for pseudoIDs is in place - return spec.SenderID(userID.String()), nil -} - // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) @@ -1718,6 +1724,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return } +// SelectUserRoomPublicKey queries the users room public key. +// If no key exists, returns no key and no error. Otherwise returns +// the key and a database error, if any. +func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) { + uID := userID.String() + stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) + if sErr != nil { + return nil, sErr + } + stateKeyNID := stateKeyNIDMap[uID] + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + return nil + } + + key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID) + if !errors.Is(sErr, sql.ErrNoRows) { + return sErr + } + return nil + }) + return +} + // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { result = make(map[spec.RoomID]map[string]string, len(publicKeys)) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 581d83ee4..c7b915c7d 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -163,12 +163,17 @@ func TestUserRoomKeys(t *testing.T) { gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) // Key doesn't exist, we shouldn't get anything back - assert.NoError(t, err) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) assert.NoError(t, err) assert.Nil(t, gotKey) + pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist) + assert.NoError(t, err) + assert.Nil(t, pubKey) queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ *roomID: {key.Public().(ed25519.PublicKey)}, diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 8af57ea0e..84c8b54ec 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` + const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserRoomPublicKeyStmt *sql.Stmt //selectUserNIDsStmt *sql.Stmt //prepared at runtime } @@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } @@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } +func (s *userRoomKeysStatements) SelectUserRoomPublicKey( + ctx context.Context, + txn *sql.Tx, + userNID types.EventStateKeyNID, + roomNID types.RoomNID, +) (ed25519.PublicKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt) + var result ed25519.PublicKey + err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return result, err +} + func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { roomNIDs := make([]any, 0, len(senderKeys)) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index cd0e51686..445c1223f 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -193,6 +193,8 @@ type UserRoomKeys interface { InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) // SelectUserRoomPrivateKey selects the private key for the given user and room combination SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + // SelectUserRoomPublicKey selects the public key for the given user and room combination + SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error) // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // If a senderKey can't be found, it is omitted in the result. BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 284309481..8802a3c6e 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { var gotKey, key2, key3 ed25519.PrivateKey + var pubKey ed25519.PublicKey gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) { gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) assert.NoError(t, err) assert.Equal(t, key, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key.Public(), pubKey) // try to update an existing key, this should only be done for users NOT on this homeserver var gotPubKey ed25519.PublicKey @@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) { gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) assert.NoError(t, err) assert.Nil(t, gotKey) + pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, pubKey) // query user NIDs for senderKeys var gotKeys map[string]types.UserRoomKeyPair diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index d3f1c9dd2..f28419905 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), Limited: res.Limited, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index e32d6a9f2..16fb3efe1 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,11 +525,11 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } -func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } -func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { +func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { return spec.SenderID(userID.String()), nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index c5f2db9c8..d468dfc98 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -377,7 +377,11 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst return sp, fmt.Errorf("unexpected nil state_key") } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return sp, err + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey())) if err != nil || userID == nil { return sp, fmt.Errorf("failed getting userID for sender: %w", err) } @@ -404,7 +408,11 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( return } - userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) + validRoomID, err := spec.NewRoomID(msg.Event.RoomID()) + if err != nil { + return + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey())) if err != nil || userID == nil { return } @@ -454,7 +462,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Notify any active sync requests that the invite has been retired. s.inviteStream.Advance(pduPos) - userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) + validRoomID, err := spec.NewRoomID(msg.RoomID) + if err != nil { + log.WithFields(log.Fields{ + "event_id": msg.EventID, + "room_id": msg.RoomID, + log.ErrorKey: err, + }).Errorf("roomID is invalid") + return + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID) if err != nil || userID == nil { log.WithFields(log.Fields{ "event_id": msg.EventID, diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index ab1a7f83d..ce6846ca4 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -139,7 +139,11 @@ func ApplyHistoryVisibilityFilter( if err != nil { return nil, err } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return nil, err + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user) if err == nil { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { eventsFiltered = append(eventsFiltered, ev) diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index f4b6ace59..24ffcc041 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -170,11 +170,15 @@ func TrackChangedUsers( return nil, nil, err } for roomID, state := range stateRes.Rooms { + validRoomID, roomErr := spec.NewRoomID(roomID) + if roomErr != nil { + continue + } for tuple, membership := range state { if membership != spec.Join { continue } - user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey)) if queryErr != nil || user == nil { continue } @@ -216,13 +220,17 @@ func TrackChangedUsers( return nil, left, err } for roomID, state := range stateRes.Rooms { + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + continue + } for tuple, membership := range state { if membership != spec.Join { continue } // new user who we weren't previously sharing rooms with if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { - user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) + user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey)) if err != nil || user == nil { continue } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index efa641475..3f5e990c4 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -64,7 +64,7 @@ type mockRoomserverAPI struct { roomIDToJoinedMembers map[string][]string } -func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 4ee7c8605..af8ab0102 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -101,13 +101,20 @@ func (n *Notifier) OnNewEvent( n._removeEmptyUserStreams() if ev != nil { + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: RoomID is invalid", + ) + return + } // Map this event's room_id to a list of joined users, and wake them up. usersToNotify := n._joinedUsers(ev.RoomID()) // Map this event's room_id to a list of peeking devices, and wake them up. peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey())) if err != nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( "Notifier.OnNewEvent: Failed to find the userID for this event", diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index 7076f7134..f86301a06 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -109,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { type TestRoomServer struct{ api.SyncRoomserverAPI } -func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 55fd3c5a2..649d77b41 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -200,10 +200,10 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) @@ -211,7 +211,7 @@ func Context( if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) @@ -224,14 +224,14 @@ func Context( } } - ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, requestedEvent) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index de790e5cd..09c2aef02 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -102,14 +102,28 @@ func GetEvent( } sender := spec.UserID{} - senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("roomID is invalid"), + } + } + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID()) if err == nil && senderUserID != nil { sender = *senderUserID } sk := events[0].StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) + evRoomID, err := spec.NewRoomID(events[0].RoomID()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("roomID is invalid"), + } + } + skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index cf6769ba4..5e5d0125f 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -152,7 +152,15 @@ func GetMemberships( } } - userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + validRoomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) if err != nil || userID == nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") return util.JSONResponse{ @@ -175,7 +183,7 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })}, } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 6784a27bd..937e20ad8 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -273,7 +273,7 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })...) } @@ -389,7 +389,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, err } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 6efa065a9..17933b2fb 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -110,19 +110,24 @@ func Relations( return util.ErrorResponse(err) } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.ErrorResponse(err) + } + // Convert the events into client events, and optionally filter based on the event // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 7d9182f47..d892b604a 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -205,9 +205,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + validRoomID, roomErr := spec.NewRoomID(ev.RoomID()) + if err != nil { + logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile") + continue + } + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID()) if queryErr != nil { - logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile") continue } @@ -231,14 +236,19 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts } sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + if err != nil { + logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile") + continue + } + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -248,10 +258,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ProfileInfo: profileInfos, @@ -272,7 +282,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 5eb094ca3..f6d7fb4eb 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -25,7 +25,7 @@ import ( type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } -func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 799e3d166..1827218b6 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -114,7 +114,14 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev }).WithError(err).Warnf("Failed to add transaction ID to event") continue } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) + roomID, err := spec.NewRoomID(in[i].RoomID()) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Room ID is invalid") + continue + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { logrus.WithFields(logrus.Fields{ "event_id": out[i].EventID(), @@ -515,7 +522,11 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI if err != nil { return "", "" } - senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) + roomID, err := spec.NewRoomID(ev.RoomID()) + if err != nil { + return "", "" + } + senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser) if err != nil { return "", "" } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 3a5badd92..7c29d84ae 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -65,14 +65,18 @@ func (p *InviteStreamProvider) IncrementalSync( for roomID, inviteEvent := range invites { user := spec.UserID{} - sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) + validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) + if err != nil { + continue + } + sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID()) if err == nil && sender != nil { user = *sender } sk := inviteEvent.StateKey() if sk != nil && *sk != "" { - skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) + skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index f728d4aea..7939dd8fa 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr @@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil @@ -585,6 +585,10 @@ func (p *PDUStreamProvider) lazyLoadMembers( if len(timelineEvents) == 0 { return stateEvents, nil } + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } // Work out which memberships to include timelineUsers := make(map[string]struct{}) if !incremental { @@ -606,8 +610,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list userID := "" - stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey())) - if err == nil && stateKeyUserID != nil { + stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) + if queryErr == nil && stateKeyUserID != nil { userID = stateKeyUserID.String() } if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index b9f13c517..19815b79b 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,7 +40,7 @@ type syncRoomserverAPI struct { rooms []*test.Room } -func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 433be39f8..6f03d9ff0 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -52,14 +52,18 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, continue // TODO: shouldn't happen? } sender := spec.UserID{} - userID, err := userIDForSender(se.RoomID(), se.SenderID()) + validRoomID, err := spec.NewRoomID(se.RoomID()) + if err != nil { + continue + } + userID, err := userIDForSender(*validRoomID, se.SenderID()) if err == nil && userID != nil { sender = *userID } sk := se.StateKey() if sk != nil && *sk != "" { - skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) + skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk)) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -95,14 +99,18 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp // It provides default logic for event.SenderID & event.StateKey -> userID conversions. func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { sender := spec.UserID{} - userID, err := userIDQuery(event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return ClientEvent{} + } + userID, err := userIDQuery(*validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString diff --git a/test/room.go b/test/room.go index b19c57ddc..da09de7c2 100644 --- a/test/room.go +++ b/test/room.go @@ -39,7 +39,7 @@ var ( roomIDCounter = int64(0) ) -func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b2dc477aa..9cb9419d4 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -302,14 +302,18 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: sender := spec.UserID{} - userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, roomErr := spec.NewRoomID(event.RoomID()) + if roomErr != nil { + return roomErr + } + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if queryErr == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) if queryErr == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -544,14 +548,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype } sender := spec.UserID{} - userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return err + } + userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey())) if queryErr == nil && skUserID != nil { skString := skUserID.String() sk = &skString @@ -644,7 +652,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { user := "" - sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err == nil { user = sender.String() } @@ -682,7 +694,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { @@ -790,7 +802,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes } default: - sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return nil, err + } + sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) if err != nil { logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) return nil, err @@ -818,7 +834,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) return nil, err } - localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) + roomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID()) + return nil, err + } + + localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID) if err != nil { logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) return nil, err diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 954247155..4dc81e74a 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } -func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) { }{ { name: "m.receipt doesn't notify", - eventContent: `{"type":"m.receipt"}`, + eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`, wantAction: pushrules.UnknownAction, wantActions: nil, }, { name: "m.reaction doesn't notify", - eventContent: `{"type":"m.reaction"}`, + eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`, wantAction: pushrules.DontNotifyAction, wantActions: []*pushrules.Action{ { @@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message notifies", - eventContent: `{"type":"m.room.message"}`, + eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ @@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{