Merge SenderID & Per Room User Key work (#3109)

This commit is contained in:
devonh 2023-06-14 14:23:46 +00:00 committed by GitHub
parent 7a2e325d10
commit e4665979bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
75 changed files with 801 additions and 379 deletions

View File

@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body. // Create the transaction body.
transaction, err := json.Marshal( transaction, err := json.Marshal(
ApplicationServiceTransaction{ ApplicationServiceTransaction{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
}, },
@ -236,7 +236,11 @@ func (s *appserviceState) backoffAndPause(err error) error {
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
user := "" user := ""
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return false
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil { if err == nil {
user = userID.String() user = userID.String()
} }

View File

@ -233,11 +233,18 @@ func RemoveLocalAlias(
} }
} }
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) validRoomID, err := spec.NewRoomID(roomIDRes.RoomID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusNotFound,
JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, JSON: spec.NotFound("The alias does not exist."),
}
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("The alias does not exist."),
} }
} }
@ -321,7 +328,15 @@ func SetVisibility(
JSON: spec.BadJSON("userID for this device is invalid"), JSON: spec.BadJSON("userID for this device is invalid"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View File

@ -64,7 +64,14 @@ func SendBan(
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"), JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -155,7 +162,14 @@ func SendKick(
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"), JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -428,7 +442,11 @@ func buildMembershipEvent(
if err != nil { if err != nil {
return nil, err return nil, err
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -437,7 +455,7 @@ func buildMembershipEvent(
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID) targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -368,7 +368,11 @@ func buildMembershipEvents(
return nil, err return nil, err
} }
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -54,7 +54,14 @@ func SendRedaction(
JSON: spec.Forbidden("userID doesn't have power level to redact"), JSON: spec.Forbidden("userID doesn't have power level to redact"),
} }
} }
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
if queryErr != nil { if queryErr != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,
@ -103,8 +110,8 @@ func SendRedaction(
JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."), JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."),
} }
} }
pl, err := plEvent.PowerLevels() pl, plErr := plEvent.PowerLevels()
if err != nil { if plErr != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 403, Code: 403,
JSON: spec.Forbidden( JSON: spec.Forbidden(
@ -134,7 +141,7 @@ func SendRedaction(
Type: spec.MRoomRedaction, Type: spec.MRoomRedaction,
Redacts: eventID, Redacts: eventID,
} }
err := proto.SetContent(r) err = proto.SetContent(r)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed")
return util.JSONResponse{ return util.JSONResponse{

View File

@ -273,7 +273,14 @@ func generateSendEvent(
JSON: spec.BadJSON("Bad userID"), JSON: spec.BadJSON("Bad userID"),
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("RoomID is invalid"),
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusNotFound,
@ -344,8 +351,8 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID)
}); err != nil { }); err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View File

@ -150,7 +150,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
for _, ev := range stateRes.StateEvents { for _, ev := range stateRes.StateEvents {
stateEvents = append( stateEvents = append(
stateEvents, stateEvents,
synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, ev), }, ev),
) )
@ -173,14 +173,19 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
} }
for _, ev := range stateAfterRes.StateEvents { for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) evRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid")
continue
}
userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := ev.StateKey() sk := ev.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -367,7 +372,7 @@ func OnIncomingStateTypeRequest(
} }
stateEvent := stateEventInStateResp{ stateEvent := stateEventInStateResp{
ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, event), }, event),
} }

View File

@ -359,7 +359,11 @@ func emit3PIDInviteEvent(
if err != nil { if err != nil {
return err return err
} }
sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,11 +11,13 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -66,10 +68,14 @@ func main() {
panic(err) panic(err)
} }
natsInstance := &jetstream.NATSInstance{}
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm,
natsInstance, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), false)
roomInfo := &types.RoomInfo{ roomInfo := &types.RoomInfo{
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
} }
stateres := state.NewStateResolution(roomserverDB, roomInfo) stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI)
if *difference { if *difference {
if len(snapshotNIDs) != 2 { if len(snapshotNIDs) != 2 {
@ -183,8 +189,8 @@ func main() {
fmt.Println("Resolving state") fmt.Println("Resolving state")
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
if err != nil { if err != nil {

View File

@ -36,11 +36,11 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil return spec.SenderID(userID.String()), nil
} }

View File

@ -154,14 +154,9 @@ func (r *FederationInternalAPI) performJoinUsingServer(
if err != nil { if err != nil {
return err return err
} }
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user)
if err != nil {
return err
}
joinInput := gomatrixserverlib.PerformJoinInput{ joinInput := gomatrixserverlib.PerformJoinInput{
UserID: user, UserID: user,
SenderID: senderID,
RoomID: room, RoomID: room,
ServerName: serverName, ServerName: serverName,
Content: content, Content: content,
@ -169,12 +164,20 @@ func (r *FederationInternalAPI) performJoinUsingServer(
PrivateKey: r.cfg.Matrix.PrivateKey, PrivateKey: r.cfg.Matrix.PrivateKey,
KeyID: r.cfg.Matrix.KeyID, KeyID: r.cfg.Matrix.KeyID,
KeyRing: r.keyRing, KeyRing: r.keyRing,
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (spec.SenderID, error) {
key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
if keyErr != nil {
return "", keyErr
}
return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
},
} }
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
@ -368,7 +371,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc) // authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
} }
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
@ -459,7 +462,11 @@ func (r *FederationInternalAPI) PerformLeave(
// Set all the fields to be what they should be, this should be a no-op // Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd" // but it's possible that the remote server returned us something "odd"
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
return err return err
} }
@ -527,7 +534,11 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU, event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState, strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) { ) (gomatrixserverlib.PDU, error) {
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -95,7 +95,7 @@ func InviteV2(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: inviteReq.Event(), InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(), StrippedState: inviteReq.InviteRoomState(),
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }
@ -188,7 +188,7 @@ func InviteV1(
StateQuerier: rsAPI.StateQuerier(), StateQuerier: rsAPI.StateQuerier(),
InviteEvent: event, InviteEvent: event,
StrippedState: strippedState, StrippedState: strippedState,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }

View File

@ -98,7 +98,7 @@ func MakeJoin(
Roomserver: rsAPI, Roomserver: rsAPI,
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{ return util.JSONResponse{
@ -118,7 +118,7 @@ func MakeJoin(
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
RoomQuerier: &roomQuerier, RoomQuerier: &roomQuerier,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
BuildEventTemplate: createJoinTemplate, BuildEventTemplate: createJoinTemplate,
@ -215,7 +215,7 @@ func SendJoin(
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys, Verifier: keys,
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }

View File

@ -87,7 +87,7 @@ func MakeLeave(
return event, stateEvents, nil return event, stateEvents, nil
} }
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
return util.JSONResponse{ return util.JSONResponse{
@ -105,7 +105,7 @@ func MakeLeave(
LocalServerName: cfg.Matrix.ServerName, LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom, LocalServerInRoom: res.RoomExists && res.IsInRoom,
BuildEventTemplate: createLeaveTemplate, BuildEventTemplate: createLeaveTemplate,
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
} }
@ -236,7 +236,14 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us // Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender // the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both. // and the state key are equal so we don't need to check both.
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Room ID is invalid."),
}
}
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID())
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusForbidden, Code: http.StatusForbidden,

View File

@ -140,7 +140,14 @@ func ExchangeThirdPartyInvite(
} }
} }
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID)) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid room ID"),
}
}
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID))
if err != nil || userID == nil { if err != nil || userID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -150,7 +157,7 @@ func ExchangeThirdPartyInvite(
senderDomain := userID.Domain() senderDomain := userID.Domain()
// Check that the state key is correct. // Check that the state key is correct.
targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey)) targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey))
if err != nil || targetUserID == nil { if err != nil || targetUserID == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

10
go.mod
View File

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
@ -42,11 +42,11 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.6 github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.9.0 golang.org/x/crypto v0.10.0
golang.org/x/image v0.5.0 golang.org/x/image v0.5.0
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
golang.org/x/sync v0.1.0 golang.org/x/sync v0.1.0
golang.org/x/term v0.8.0 golang.org/x/term v0.9.0
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0 gotest.tools/v3 v3.4.0
@ -127,8 +127,8 @@ require (
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect golang.org/x/sys v0.9.0 // indirect
golang.org/x/text v0.9.0 // indirect golang.org/x/text v0.10.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect google.golang.org/protobuf v1.28.1 // indirect

20
go.sum
View File

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 h1:k75Fy0iQVbDjvddip/x898+BdyopBNAfL1BMNx0awA0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -511,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -669,12 +669,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -683,8 +683,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
case SenderKind: case SenderKind:
userID := "" userID := ""
sender, err := userIDForSender(event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return false, err
}
sender, err := userIDForSender(*validRoomID, event.SenderID())
if err == nil { if err == nil {
userID = sender.String() userID = sender.String()
} }

View File

@ -8,7 +8,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) {
{"emptyOverride", OverrideKind, emptyRule, `{}`, true}, {"emptyOverride", OverrideKind, emptyRule, `{}`, true},
{"emptyContent", ContentKind, emptyRule, `{}`, false}, {"emptyContent", ContentKind, emptyRule, `{}`, false},
{"emptyRoom", RoomKind, emptyRule, `{}`, true}, {"emptyRoom", RoomKind, emptyRule, `{}`, true},
{"emptySender", SenderKind, emptyRule, `{}`, true}, {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true},
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
{"disabled", OverrideKind, Rule{}, `{}`, false}, {"disabled", OverrideKind, Rule{}, `{}`, false},
@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) {
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false},
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {

View File

@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
} }
continue continue
} }
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())

View File

@ -70,7 +70,7 @@ type FakeRsAPI struct {
bannedFromRoom bool bannedFromRoom bool
} }
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -642,7 +642,7 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View File

@ -51,6 +51,7 @@ type RoomserverInternalAPI interface {
UserRoomserverAPI UserRoomserverAPI
FederationRoomserverAPI FederationRoomserverAPI
QuerySenderIDAPI QuerySenderIDAPI
UserRoomPrivateKeyCreator
// needed to avoid chicken and egg scenario when setting up the // needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs // interdependencies between the roomserver and other input APIs
@ -67,7 +68,9 @@ type RoomserverInternalAPI interface {
req *QueryAuthChainRequest, req *QueryAuthChainRequest,
res *QueryAuthChainResponse, res *QueryAuthChainResponse,
) error ) error
}
type UserRoomPrivateKeyCreator interface {
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
} }
@ -81,8 +84,8 @@ type InputRoomEventsAPI interface {
} }
type QuerySenderIDAPI interface { type QuerySenderIDAPI interface {
QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error)
} }
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
@ -228,6 +231,7 @@ type FederationRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
UserRoomPrivateKeyCreator
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error

View File

@ -15,7 +15,7 @@ package auth
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
@ -25,7 +25,7 @@ import (
// IsServerAllowed returns true if the server is allowed to see events in the room // IsServerAllowed returns true if the server is allowed to see events in the room
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
func IsServerAllowed( func IsServerAllowed(
ctx context.Context, db storage.RoomDatabase, ctx context.Context, querier api.QuerySenderIDAPI,
serverName spec.ServerName, serverName spec.ServerName,
serverCurrentlyInRoom bool, serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.PDU, authEvents []gomatrixserverlib.PDU,
@ -41,7 +41,7 @@ func IsServerAllowed(
return true return true
} }
// 2. If the user's membership was join, allow. // 2. If the user's membership was join, allow.
joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join) joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join)
if joinedUserExists { if joinedUserExists {
return true return true
} }
@ -50,7 +50,7 @@ func IsServerAllowed(
return true return true
} }
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow. // 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite) invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite)
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
return true return true
} }
@ -74,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
return visibility return visibility
} }
func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
for _, ev := range authEvents { for _, ev := range authEvents {
if ev.Type() != spec.MRoomMember { if ev.Type() != spec.MRoomMember {
continue continue
@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabas
continue continue
} }
userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
continue
}
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err != nil { if err != nil {
continue continue
} }

View File

@ -4,17 +4,17 @@ import (
"context" "context"
"testing" "testing"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
) )
type FakeStorageDB struct { type FakeQuerier struct {
storage.RoomDatabase api.QuerySenderIDAPI
} }
func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -87,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
authEvents = append(authEvents, ev.PDU) authEvents = append(authEvents, ev.PDU)
} }
if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
} }
}) })

View File

@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
return nil return nil
} }
// nolint:gocyclo
// RemoveRoomAlias implements alias.RoomserverInternalAPI // RemoveRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) RemoveRoomAlias( func (r *RoomserverInternalAPI) RemoveRoomAlias(
ctx context.Context, ctx context.Context,
@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil return nil
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID)
if err != nil || sender == nil { if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err) return fmt.Errorf("r.QueryUserIDForSender: %w", err)
} }
@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
if request.SenderID != ev.SenderID() { if request.SenderID != ev.SenderID() {
senderID = ev.SenderID() senderID = ev.SenderID()
} }
sender, err := r.QueryUserIDForSender(ctx, roomID, senderID) sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil { if err != nil || sender == nil {
return err return err
} }
@ -206,7 +212,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
} }
stateRes := &api.QueryLatestEventsAndStateResponse{} stateRes := &api.QueryLatestEventsAndStateResponse{}
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil { if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
return err return err
} }

View File

@ -177,6 +177,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
IsLocalServerName: r.Cfg.Global.IsLocalServerName, IsLocalServerName: r.Cfg.Global.IsLocalServerName,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
Querier: r.Queryer,
KeyRing: r.KeyRing, KeyRing: r.KeyRing,
// Perspective servers are trusted to not lie about server keys, so we will also // Perspective servers are trusted to not lie about server keys, so we will also
// prefer these servers when backfilling (assuming they are in the room) rather // prefer these servers when backfilling (assuming they are in the room) rather

View File

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -36,6 +37,7 @@ func CheckForSoftFail(
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
event *types.HeaderedEvent, event *types.HeaderedEvent,
stateEventIDs []string, stateEventIDs []string,
querier api.QuerySenderIDAPI,
) (bool, error) { ) (bool, error) {
rewritesState := len(stateEventIDs) > 1 rewritesState := len(stateEventIDs) > 1
@ -49,7 +51,7 @@ func CheckForSoftFail(
} else { } else {
// Then get the state entries for the current state snapshot. // Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now. // We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, roomInfo) roomState := state.NewStateResolution(db, roomInfo, querier)
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID()) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
if err != nil { if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
@ -76,8 +78,8 @@ func CheckForSoftFail(
} }
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return db.GetUserIDForSender(ctx, roomID, senderID) return querier.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
// return true, nil // return true, nil
return true, err return true, err

View File

@ -68,7 +68,7 @@ func UpdateToInviteMembership(
// memberships. If the servername is not supplied then the local server will be // memberships. If the servername is not supplied then the local server will be
// checked instead using a faster code path. // checked instead using a faster code path.
// TODO: This should probably be replaced by an API call. // TODO: This should probably be replaced by an API call.
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) { func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) {
info, err := db.RoomInfo(ctx, roomID) info, err := db.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
return false, err return false, err
@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
for i := range events { for i := range events {
gmslEvents[i] = events[i].PDU gmslEvents[i] = events[i].PDU
} }
return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil
} }
func IsInvitePending( func IsInvitePending(
@ -211,8 +211,8 @@ func GetMembershipsAtState(
return events, nil return events, nil
} }
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info, querier)
// Lookup the event NID // Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil { if err != nil {
@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState) return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
} }
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info, querier)
// Fetch the state as it was when this event was fired // Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
} }
@ -264,7 +264,7 @@ func LoadStateEvents(
} }
func CheckServerAllowedToSeeEvent( func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI,
) (bool, error) { ) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err { switch err {
@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError: case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using // The database engine didn't support this optimisation, so fall back to using
// the old and slow method // the old and slow method
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName) stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent(
return false, err return false, err
} }
} }
return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil
} }
func slowGetHistoryVisibilityState( func slowGetHistoryVisibilityState(
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]gomatrixserverlib.PDU, error) { ) ([]gomatrixserverlib.PDU, error) {
roomState := state.NewStateResolution(db, info) roomState := state.NewStateResolution(db, info, querier)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState(
// If the event state key doesn't match the given servername // If the event state key doesn't match the given servername
// then we'll filter it out. This does preserve state keys that // then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc. // are "" since these will contain history visibility etc.
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
for nid, key := range stateKeys { for nid, key := range stateKeys {
if key != "" { if key != "" {
userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key)) userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key))
if err == nil && userID != nil { if err == nil && userID != nil {
if userID.Domain() != serverName { if userID.Domain() != serverName {
delete(stateKeys, nid) delete(stateKeys, nid)
@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState(
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
func ScanEventTree( func ScanEventTree(
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName spec.ServerName, serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]types.EventNID, map[string]struct{}, error) { ) ([]types.EventNID, map[string]struct{}, error) {
var resultNIDs []types.EventNID var resultNIDs []types.EventNID
var err error var err error
@ -392,7 +396,7 @@ BFSLoop:
// It's nasty that we have to extract the room ID from an event, but many federation requests // It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!) // only talk in event IDs, no room IDs at all (!!!)
ev := events[0] ev := events[0]
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID())
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
} }
@ -415,7 +419,7 @@ BFSLoop:
// hasn't been seen before. // hasn't been seen before.
if !visited[pre] { if !visited[pre] {
visited[pre] = true visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom) allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier)
if err != nil { if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event", "Error checking if allowed to see event",
@ -444,7 +448,7 @@ BFSLoop:
} }
func QueryLatestEventsAndState( func QueryLatestEventsAndState(
ctx context.Context, db storage.Database, ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI,
request *api.QueryLatestEventsAndStateRequest, request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse, response *api.QueryLatestEventsAndStateResponse,
) error { ) error {
@ -457,7 +461,7 @@ func QueryLatestEventsAndState(
return nil return nil
} }
roomState := state.NewStateResolution(db, roomInfo) roomState := state.NewStateResolution(db, roomInfo, querier)
response.RoomExists = true response.RoomExists = true
response.RoomVersion = roomInfo.RoomVersion response.RoomVersion = roomInfo.RoomVersion

View File

@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent { if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
} }
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
} }
@ -282,8 +286,8 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted. // we consider the event to be "rejected" — it will still be persisted.
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
isRejected = true isRejected = true
rejectionErr = err rejectionErr = err
@ -321,7 +325,7 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindNew && !isCreateEvent { if input.Kind == api.KindNew && !isCreateEvent {
// Check that the event passes authentication checks based on the // Check that the event passes authentication checks based on the
// current room state. // current room state.
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer)
if err != nil { if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event") logger.WithError(err).Warn("Error authing soft-failed event")
} }
@ -401,7 +405,7 @@ func (r *Inputer) processRoomEvent(
redactedEvent gomatrixserverlib.PDU redactedEvent gomatrixserverlib.PDU
) )
if !isRejected && !isCreateEvent { if !isRejected && !isCreateEvent {
resolver := state.NewStateResolution(r.DB, roomInfo) resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver)
if err != nil { if err != nil {
return err return err
@ -587,8 +591,8 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents( stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent), gomatrixserverlib.ToPDUs(stateBeforeEvent),
) )
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil { }); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return return
@ -700,8 +704,8 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply // Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem // skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway. // if a critical event is missing anyway.
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue nextAuthEvent continue nextAuthEvent
} }
@ -718,8 +722,8 @@ nextAuthEvent:
} }
// Check if the auth event should be rejected. // Check if the auth event should be rejected.
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if isRejected = err != nil; isRejected { if isRejected = err != nil; isRejected {
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
@ -783,7 +787,7 @@ func (r *Inputer) calculateAndSetState(
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err) return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
} }
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
roomState := state.NewStateResolution(updater, roomInfo) roomState := state.NewStateResolution(updater, roomInfo, r.Queryer)
if input.HasState { if input.HasState {
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
@ -836,13 +840,18 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
return err return err
} }
validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents { for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil { if memberEvent.StateKey() == nil {
continue continue
} }
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey())) memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey()))
if err != nil { if err != nil {
continue continue
} }

View File

@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
} }
// Finally check that the event is NOT allowed // Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
}); err == nil { }); err == nil {
t.Fatalf("event should not be allowed, but it was") t.Fatalf("event should not be allowed, but it was")

View File

@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error {
defer trace.EndRegion() defer trace.EndRegion()
var err error var err error
roomState := state.NewStateResolution(u.updater, u.roomInfo) roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer)
// Work out if the state at the extremities has actually changed // Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the // or not. If they haven't then we won't bother doing all of the

View File

@ -139,7 +139,11 @@ func (r *Inputer) updateMembership(
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return isTargetLocalUser
}
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey))
if err != nil || userID == nil { if err != nil || userID == nil {
return isTargetLocalUser return isTargetLocalUser
} }

View File

@ -383,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
defer trace.EndRegion() defer trace.EndRegion()
var res parsedRespState var res parsedRespState
roomState := state.NewStateResolution(t.db, t.roomInfo) roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer)
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
if err != nil { if err != nil {
t.log.WithError(err).Warnf("failed to get state after %s locally", eventID) t.log.WithError(err).Warnf("failed to get state after %s locally", eventID)
@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...) stateEventList = append(stateEventList, state.StateEvents...)
} }
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
if err != nil { if err != nil {
@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
} }
// apply the current event // apply the current event
retryAllowedState: retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed. // will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
} }
@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(), StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(), AuthEvents: state.GetAuthEvents(),
}, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
} }
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return t.db.GetUserIDForSender(ctx, roomID, senderID) return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}

View File

@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom(
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return nil, err return nil, err
} }
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
prevEvents := latestRes.LatestEvents prevEvents := latestRes.LatestEvents
var senderDomain spec.ServerName var senderDomain spec.ServerName
@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
PrevEvents: prevEvents, PrevEvents: prevEvents,
} }
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID)) userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID))
if err != nil || userID == nil { if err != nil || userID == nil {
continue continue
} }
@ -264,16 +268,16 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
} }
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
} }
authEventMap[authEvent.EventID()] = authEvent authEventMap[authEvent.EventID()] = authEvent
} }
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
continue continue
} }
@ -293,7 +297,11 @@ func (r *Admin) PerformAdminDownloadState(
stateIDs = append(stateIDs, stateEvent.EventID()) stateIDs = append(stateIDs, stateEvent.EventID())
} }
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -42,6 +42,7 @@ type Backfiller struct {
DB storage.Database DB storage.Database
FSAPI federationAPI.RoomserverFederationAPI FSAPI federationAPI.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
Querier api.QuerySenderIDAPI
// The servers which should be preferred above other servers when backfilling // The servers which should be preferred above other servers when backfilling
PreferServers []spec.ServerName PreferServers []spec.ServerName
@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill(
} }
// Scan the event tree for events to send back. // Scan the event tree for events to send back.
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier)
if err != nil { if err != nil {
return err return err
} }
@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if info == nil || info.IsStub() { if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
} }
requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
// Request 100 items regardless of what the query asks for. // Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this. // We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
@ -121,8 +122,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events" // Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
// Only return an error if we really couldn't get any events. // Only return an error if we really couldn't get any events.
@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
// persist these new events - auth checks have already been done // persist these new events - auth checks have already been done
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events)
for _, ev := range backfilledEventMap { for _, ev := range backfilledEventMap {
// now add state for these events // now add state for these events
@ -212,8 +213,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue continue
} }
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
logger.WithError(err).Warn("failed to load and verify event") logger.WithError(err).Warn("failed to load and verify event")
@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
} }
} }
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
persistEvents(ctx, r.DB, newEvents) persistEvents(ctx, r.DB, r.Querier, newEvents)
} }
// backfillRequester implements gomatrixserverlib.BackfillRequester // backfillRequester implements gomatrixserverlib.BackfillRequester
type backfillRequester struct { type backfillRequester struct {
db storage.Database db storage.Database
fsAPI federationAPI.RoomserverFederationAPI fsAPI federationAPI.RoomserverFederationAPI
querier api.QuerySenderIDAPI
virtualHost spec.ServerName virtualHost spec.ServerName
isLocalServerName func(spec.ServerName) bool isLocalServerName func(spec.ServerName) bool
preferServer map[spec.ServerName]bool preferServer map[spec.ServerName]bool
@ -268,6 +270,7 @@ type backfillRequester struct {
func newBackfillRequester( func newBackfillRequester(
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
querier api.QuerySenderIDAPI,
virtualHost spec.ServerName, virtualHost spec.ServerName,
isLocalServerName func(spec.ServerName) bool, isLocalServerName func(spec.ServerName) bool,
bwExtrems map[string][]string, preferServers []spec.ServerName, bwExtrems map[string][]string, preferServers []spec.ServerName,
@ -279,6 +282,7 @@ func newBackfillRequester(
return &backfillRequester{ return &backfillRequester{
db: db, db: db,
fsAPI: fsAPI, fsAPI: fsAPI,
querier: querier,
virtualHost: virtualHost, virtualHost: virtualHost,
isLocalServerName: isLocalServerName, isLocalServerName: isLocalServerName,
eventIDToBeforeStateIDs: make(map[string][]string), eventIDToBeforeStateIDs: make(map[string][]string),
@ -460,14 +464,14 @@ FindSuccessor:
return nil return nil
} }
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier)
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil return nil
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost)
b.historyVisiblity = visibility b.historyVisiblity = visibility
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@ -488,7 +492,11 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates. // Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool) serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents { for _, event := range memberEvents {
if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
continue
}
if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
serverSet[sender.Domain()] = true serverSet[sender.Domain()] = true
} }
} }
@ -554,7 +562,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
// pull all events and then filter by that table. // pull all events and then filter by that table.
func joinEventsFromHistoryVisibility( func joinEventsFromHistoryVisibility(
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
@ -582,7 +590,7 @@ func joinEventsFromHistoryVisibility(
} }
// Can we see events in the room? // Can we see events in the room?
canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events) canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events)
visibility := auth.HistoryVisibilityForRoom(events) visibility := auth.HistoryVisibilityForRoom(events)
if !canSeeEvents { if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
@ -597,7 +605,7 @@ func joinEventsFromHistoryVisibility(
return evs, visibility, err return evs, visibility, err
} }
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) { func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
var roomNID types.RoomNID var roomNID types.RoomNID
var eventNID types.EventNID var eventNID types.EventNID
backfilledEventMap := make(map[string]types.Event) backfilledEventMap := make(map[string]types.Event)
@ -639,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
continue continue
} }
resolver := state.NewStateResolution(db, roomInfo) resolver := state.NewStateResolution(db, roomInfo, querier)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver) _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
if err != nil { if err != nil {

View File

@ -63,13 +63,20 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
} }
senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID) var senderID spec.SenderID
if err != nil { if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") // create user room key if needed
return "", &util.JSONResponse{ key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
Code: http.StatusInternalServerError, if keyErr != nil {
JSON: spec.InternalServerError{}, util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
} }
senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
} else {
senderID = spec.SenderID(userID.String())
} }
createContent["creator"] = senderID createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion createContent["room_version"] = createRequest.RoomVersion
@ -323,8 +330,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return c.DB.GetUserIDForSender(ctx, roomID, senderID) return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{
@ -364,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
// create user room key if needed
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
_, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
}
// send the remaining events // send the remaining events
if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
@ -455,7 +450,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID) inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID)
if queryErr != nil { if queryErr != nil {
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed") util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{ return "", &util.JSONResponse{

View File

@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek(
response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]} response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]}
// XXX: do we actually need to do a state resolution here? // XXX: do we actually need to do a state resolution here?
roomState := state.NewStateResolution(r.DB, info) roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer)
var stateEntries []types.StateEntry var stateEntries []types.StateEntry
stateEntries, err = roomState.LoadStateAtSnapshot( stateEntries, err = roomState.LoadStateAtSnapshot(

View File

@ -34,6 +34,7 @@ import (
type QueryState struct { type QueryState struct {
storage.Database storage.Database
querier api.QuerySenderIDAPI
} }
func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) { func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) {
@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant
return nil, fmt.Errorf("failed to load RoomInfo: %w", err) return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
} }
if info != nil { if info != nil {
roomState := state.NewStateResolution(q.Database, info) roomState := state.NewStateResolution(q.Database, info, q.querier)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
ctx, info.StateSnapshotNID(), stateWanted, ctx, info.StateSnapshotNID(), stateWanted,
) )
@ -98,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership(
var outputUpdates []api.OutputEvent var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater var updater *shared.MembershipUpdater
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
if err != nil {
return nil, err
}
userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err != nil { if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
} }
@ -126,7 +131,12 @@ func (r *Inviter) PerformInvite(
) error { ) error {
event := req.Event event := req.Event
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
return spec.InvalidParam("The sender user ID is invalid") return spec.InvalidParam("The sender user ID is invalid")
} }
@ -137,18 +147,13 @@ func (r *Inviter) PerformInvite(
if event.StateKey() == nil || *event.StateKey() == "" { if event.StateKey() == nil || *event.StateKey() == "" {
return fmt.Errorf("invite must be a state event") return fmt.Errorf("invite must be a state event")
} }
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err != nil || invitedUser == nil { if err != nil || invitedUser == nil {
return spec.InvalidParam("Could not find the matching senderID for this user") return spec.InvalidParam("Could not find the matching senderID for this user")
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain()) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
validRoomID, err := spec.NewRoomID(event.RoomID()) invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
if err != nil {
return err
}
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser)
if err != nil { if err != nil {
return fmt.Errorf("failed looking up senderID for invited user") return fmt.Errorf("failed looking up senderID for invited user")
} }
@ -161,9 +166,9 @@ func (r *Inviter) PerformInvite(
IsTargetLocal: isTargetLocal, IsTargetLocal: isTargetLocal,
StrippedState: req.InviteRoomState, StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB}, StateQuerier: &QueryState{r.DB, r.RSAPI},
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
} }
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)

View File

@ -25,6 +25,7 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -174,44 +175,6 @@ func (r *Joiner) performJoinRoomByID(
req.ServerNames = append(req.ServerNames, roomID.Domain()) req.ServerNames = append(req.ServerNames, roomID.Domain())
} }
// Prepare the template for the join event.
userID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
}
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID)
if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
}
senderIDString := string(senderID)
userDomain := userID.Domain()
proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember,
SenderID: senderIDString,
StateKey: &senderIDString,
RoomID: req.RoomIDOrAlias,
Redacts: "",
}
if err = proto.SetUnsigned(struct{}{}); err != nil {
return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
}
// It is possible for the request to include some "content" for the
// event. We'll always overwrite the "membership" key, but the rest,
// like "display_name" or "avatar_url", will be kept if supplied.
if req.Content == nil {
req.Content = map[string]interface{}{}
}
req.Content["membership"] = spec.Join
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
return "", "", aerr
} else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia
}
if err = proto.SetContent(req.Content); err != nil {
return "", "", fmt.Errorf("eb.SetContent: %w", err)
}
// Force a federated join if we aren't in the room and we've been // Force a federated join if we aren't in the room and we've been
// given some server names to try joining by. // given some server names to try joining by.
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{ inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
@ -224,29 +187,63 @@ func (r *Joiner) performJoinRoomByID(
serverInRoom := inRoomRes.IsInRoom serverInRoom := inRoomRes.IsInRoom
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
userID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
}
// Look up the room NID for the supplied room ID.
var senderID spec.SenderID
checkInvitePending := false
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
if err == nil && info != nil {
switch info.RoomVersion {
case gomatrixserverlib.RoomVersionPseudoIDs:
senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
if err == nil {
checkInvitePending = true
} else {
// create user room key if needed
key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
if keyErr != nil {
util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr)
}
senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
}
default:
checkInvitePending = true
senderID = spec.SenderID(userID.String())
}
}
userDomain := userID.Domain()
// Force a federated join if we're dealing with a pending invite // Force a federated join if we're dealing with a pending invite
// and we aren't in the room. // and we aren't in the room.
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) if checkInvitePending {
if err == nil && !serverInRoom && isInvitePending { isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender) if inviteErr == nil && !serverInRoom && isInvitePending {
if queryErr != nil { inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) if queryErr != nil {
} return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
}
// If we were invited by someone from another server then we can // If we were invited by someone from another server then we can
// assume they are in the room so we can join via them. // assume they are in the room so we can join via them.
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
req.ServerNames = append(req.ServerNames, inviter.Domain()) req.ServerNames = append(req.ServerNames, inviter.Domain())
forceFederatedJoin = true forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON())) memberEvent := gjson.Parse(string(inviteEvent.JSON()))
// only set unsigned if we've got a content.membership, which we _should_ // only set unsigned if we've got a content.membership, which we _should_
if memberEvent.Get("content.membership").Exists() { if memberEvent.Get("content.membership").Exists() {
req.Unsigned = map[string]interface{}{ req.Unsigned = map[string]interface{}{
"prev_sender": memberEvent.Get("sender").Str, "prev_sender": memberEvent.Get("sender").Str,
"prev_content": map[string]interface{}{ "prev_content": map[string]interface{}{
"is_direct": memberEvent.Get("content.is_direct").Bool(), "is_direct": memberEvent.Get("content.is_direct").Bool(),
"membership": memberEvent.Get("content.membership").Str, "membership": memberEvent.Get("content.membership").Str,
}, },
}
} }
} }
} }
@ -274,6 +271,7 @@ func (r *Joiner) performJoinRoomByID(
// If we should do a forced federated join then do that. // If we should do a forced federated join then do that.
var joinedVia spec.ServerName var joinedVia spec.ServerName
if forceFederatedJoin { if forceFederatedJoin {
// TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet
joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
return req.RoomIDOrAlias, joinedVia, err return req.RoomIDOrAlias, joinedVia, err
} }
@ -289,19 +287,40 @@ func (r *Joiner) performJoinRoomByID(
if err != nil { if err != nil {
return "", "", fmt.Errorf("error joining local room: %q", err) return "", "", fmt.Errorf("error joining local room: %q", err)
} }
senderIDString := string(senderID)
// Prepare the template for the join event.
proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember,
SenderID: senderIDString,
StateKey: &senderIDString,
RoomID: req.RoomIDOrAlias,
Redacts: "",
}
if err = proto.SetUnsigned(struct{}{}); err != nil {
return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
}
// It is possible for the request to include some "content" for the
// event. We'll always overwrite the "membership" key, but the rest,
// like "display_name" or "avatar_url", will be kept if supplied.
if req.Content == nil {
req.Content = map[string]interface{}{}
}
req.Content["membership"] = spec.Join
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
return "", "", aerr
} else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia
}
if err = proto.SetContent(req.Content); err != nil {
return "", "", fmt.Errorf("eb.SetContent: %w", err)
}
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes) event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes)
switch err.(type) { switch err.(type) {
case nil: case nil:
// create user room key if needed
if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
_, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
if err != nil {
logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
return "", "", fmt.Errorf("failed to get user room private key: %w", err)
}
}
// The room join is local. Send the new join event into the // The room join is local. Send the new join event into the
// roomserver. First of all check that the user isn't already // roomserver. First of all check that the user isn't already
// a member of the room. This is best-effort (as in we won't // a member of the room. This is best-effort (as in we won't

View File

@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID(
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver) roomID, err := spec.NewRoomID(req.RoomID)
if err != nil {
return nil, err
}
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
if err != nil { if err != nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
} }
@ -87,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID(
// that. // that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver) isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending { if err == nil && isInvitePending {
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
if serr != nil || sender == nil { if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q has no matching userID", senderUser) return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
} }
@ -133,7 +137,7 @@ func (r *Leaver) performLeaveRoomByID(
}, },
} }
latestRes := api.QueryLatestEventsAndStateResponse{} latestRes := api.QueryLatestEventsAndStateResponse{}
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil { if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
return nil, err return nil, err
} }
if !latestRes.RoomExists { if !latestRes.RoomExists {

View File

@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade(
return "", err return "", err
} }
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID) fullRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return "", err
}
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user") util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", err return "", err
@ -488,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
} }
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
@ -569,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
stateEvents[i] = queryRes.StateEvents[i].PDU stateEvents[i] = queryRes.StateEvents[i].PDU
} }
provider := gomatrixserverlib.NewAuthEvents(stateEvents) provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil { }); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?

View File

@ -16,6 +16,7 @@ package query
import ( import (
"context" "context"
"crypto/ed25519"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest, request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse, response *api.QueryLatestEventsAndStateResponse,
) error { ) error {
return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response)
} }
// QueryStateAfterEvents implements api.RoomserverInternalAPI // QueryStateAfterEvents implements api.RoomserverInternalAPI
@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents(
return nil return nil
} }
roomState := state.NewStateResolution(r.DB, info) roomState := state.NewStateResolution(r.DB, info, r)
response.RoomExists = true response.RoomExists = true
response.RoomVersion = info.RoomVersion response.RoomVersion = info.RoomVersion
@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents(
} }
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
if err != nil { if err != nil {
@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest, request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
if err != nil {
return err
}
roomID, err := spec.NewRoomID(request.RoomID) roomID, err := spec.NewRoomID(request.RoomID)
if err != nil { if err != nil {
return err return err
} }
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
} }
@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent(
} }
response.Membership = make(map[string]*types.HeaderedEvent) response.Membership = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
if err != nil { if err != nil {
return fmt.Errorf("unable to get state before event: %w", err) return fmt.Errorf("unable to get state before event: %w", err)
} }
@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err) return fmt.Errorf("r.DB.Events: %w", err)
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event) }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
} else { } else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r)
if err != nil { if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err return err
@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom(
} }
for _, event := range events { for _, event := range events {
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event) }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent) response.JoinEvents = append(response.JoinEvents, clientEvent)
@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
} }
return helpers.CheckServerAllowedToSeeEvent( return helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, roomID, eventID, serverName, isInRoom, ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r,
) )
} }
@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
} }
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r)
if err != nil { if err != nil {
return err return err
} }
@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState { if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts( stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) return r.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
if err != nil { if err != nil {
@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// first bool: is rejected, second bool: state missing // first bool: is rejected, second bool: state missing
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) { func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
roomState := state.NewStateResolution(r.DB, roomInfo) roomState := state.NewStateResolution(r.DB, roomInfo, r)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
} }
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.DB.GetSenderIDForUser(ctx, roomID, userID) version, err := r.DB.GetRoomVersion(ctx, roomID.String())
if err != nil {
return "", err
}
switch version {
case gomatrixserverlib.RoomVersionPseudoIDs:
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
if err != nil {
return "", err
}
return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
default:
return spec.SenderID(userID.String()), nil
}
} }
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.DB.GetUserIDForSender(ctx, roomID, senderID) userID, err := spec.NewUserID(string(senderID), true)
if err == nil {
return userID, nil
}
bytes := spec.Base64Bytes{}
err = bytes.Decode(string(senderID))
if err != nil {
return nil, err
}
queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}}
result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap)
if err != nil {
return nil, err
}
if userKeys, ok := result[roomID]; ok {
if userID, ok := userKeys[string(senderID)]; ok {
return spec.NewUserID(userID, true)
}
}
return nil, nil
} }

View File

@ -516,6 +516,9 @@ func TestRedaction(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
natsInstance := &jetstream.NATSInstance{}
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
authEvents := []types.EventNID{} authEvents := []types.EventNID{}
@ -551,7 +554,7 @@ func TestRedaction(t *testing.T) {
} }
// Calculate the snapshotNID etc. // Calculate the snapshotNID etc.
plResolver := state.NewStateResolution(db, roomInfo) plResolver := state.NewStateResolution(db, roomInfo, rsAPI)
stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false) stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -29,6 +29,7 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -44,20 +45,21 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
type StateResolution struct { type StateResolution struct {
db StateResolutionStorage db StateResolutionStorage
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
events map[types.EventNID]gomatrixserverlib.PDU events map[types.EventNID]gomatrixserverlib.PDU
Querier api.QuerySenderIDAPI
} }
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution {
return StateResolution{ return StateResolution{
db: db, db: db,
roomInfo: roomInfo, roomInfo: roomInfo,
events: make(map[types.EventNID]gomatrixserverlib.PDU), events: make(map[types.EventNID]gomatrixserverlib.PDU),
Querier: querier,
} }
} }
@ -947,8 +949,8 @@ func (v *StateResolution) resolveConflictsV1(
} }
// Resolve the conflicts. // Resolve the conflicts.
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomID, senderID) return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// Map from the full events back to numeric state entries. // Map from the full events back to numeric state entries.
@ -1061,8 +1063,8 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents, conflictedEvents,
nonConflictedEvents, nonConflictedEvents,
authEvents, authEvents,
func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return v.db.GetUserIDForSender(ctx, roomID, senderID) return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
}, },
) )
}() }()

View File

@ -169,10 +169,6 @@ type Database interface {
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownUsers tries to obtain the current mxid for a given user.
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
// GetKnownUsers tries to obtain the current senderID for a given user.
GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error) GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@ -190,6 +186,7 @@ type Database interface {
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*types.HeaderedEvent, error) ) (map[string]*types.HeaderedEvent, error)
GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error)
GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent( MaybeRedactEvent(
@ -205,8 +202,12 @@ type UserRoomKeys interface {
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
// SelectUserRoomPrivateKey selects the private key for the given user and room combination // SelectUserRoomPrivateKey selects the private key for the given user and room combination
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
// SelectUserRoomPublicKey selects the public key for the given user and room combination
SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error)
// SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID. // SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
// TODO: Why is the result map indexed by string not public key?
// TODO: Shouldn't the input & result map be changed to be indexed by string instead of the RoomID struct?
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
} }
@ -233,7 +234,6 @@ type RoomDatabase interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
} }
type EventDatabase interface { type EventDatabase interface {

View File

@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt
} }
@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL}, {&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err return result, err
} }
func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
ctx context.Context,
txn *sql.Tx,
userNID types.EventStateKeyNID,
roomNID types.RoomNID,
) (ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
var result ed25519.PublicKey
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return result, err
}
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -251,7 +250,3 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
} }
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return u.d.GetUserIDForSender(ctx, roomID, senderID)
}

View File

@ -721,6 +721,22 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver
}, err }, err
} }
func (d *Database) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID)
if versionOK {
return cachedRoomVersion, nil
}
roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil {
return "", err
}
if roomInfo == nil {
return "", nil
}
return roomInfo.RoomVersion, nil
}
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil { if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {
@ -1550,16 +1566,6 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return spec.NewUserID(string(senderID), true)
}
func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
// TODO: Use real logic once DB for pseudoIDs is in place
return spec.SenderID(userID.String()), nil
}
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
@ -1718,6 +1724,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use
return return
} }
// SelectUserRoomPublicKey queries the users room public key.
// If no key exists, returns no key and no error. Otherwise returns
// the key and a database error, if any.
func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) {
uID := userID.String()
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
if sErr != nil {
return nil, sErr
}
stateKeyNID := stateKeyNIDMap[uID]
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
if rErr != nil {
return rErr
}
if roomInfo == nil {
return nil
}
key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)
if !errors.Is(sErr, sql.ErrNoRows) {
return sErr
}
return nil
})
return
}
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
result = make(map[spec.RoomID]map[string]string, len(publicKeys)) result = make(map[spec.RoomID]map[string]string, len(publicKeys))

View File

@ -163,12 +163,17 @@ func TestUserRoomKeys(t *testing.T) {
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, key, gotKey) assert.Equal(t, key, gotKey)
pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID)
assert.NoError(t, err)
assert.Equal(t, key.Public(), pubKey)
// Key doesn't exist, we shouldn't get anything back // Key doesn't exist, we shouldn't get anything back
assert.NoError(t, err)
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist) gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist)
assert.NoError(t, err)
assert.Nil(t, pubKey)
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
*roomID: {key.Public().(ed25519.PublicKey)}, *roomID: {key.Public().(ed25519.PublicKey)},

View File

@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime //selectUserNIDsStmt *sql.Stmt //prepared at runtime
} }
@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err return result, err
} }
func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
ctx context.Context,
txn *sql.Tx,
userNID types.EventStateKeyNID,
roomNID types.RoomNID,
) (ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
var result ed25519.PublicKey
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return result, err
}
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
roomNIDs := make([]any, 0, len(senderKeys)) roomNIDs := make([]any, 0, len(senderKeys))

View File

@ -193,6 +193,8 @@ type UserRoomKeys interface {
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
// SelectUserRoomPrivateKey selects the private key for the given user and room combination // SelectUserRoomPrivateKey selects the private key for the given user and room combination
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
// SelectUserRoomPublicKey selects the public key for the given user and room combination
SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error)
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)

View File

@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) {
err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var gotKey, key2, key3 ed25519.PrivateKey var gotKey, key2, key3 ed25519.PrivateKey
var pubKey ed25519.PublicKey
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key) gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, gotKey, key) assert.Equal(t, gotKey, key)
@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) {
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID) gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, key, gotKey) assert.Equal(t, key, gotKey)
pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID)
assert.NoError(t, err)
assert.Equal(t, key.Public(), pubKey)
// try to update an existing key, this should only be done for users NOT on this homeserver // try to update an existing key, this should only be done for users NOT on this homeserver
var gotPubKey ed25519.PublicKey var gotPubKey ed25519.PublicKey
@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) {
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2) gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, gotKey) assert.Nil(t, gotKey)
pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2)
assert.NoError(t, err)
assert.Nil(t, pubKey)
// query user NIDs for senderKeys // query user NIDs for senderKeys
var gotKeys map[string]types.UserRoomKeyPair var gotKeys map[string]types.UserRoomKeyPair

View File

@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct {
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{ out := &EventRelationshipResponse{
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
Limited: res.Limited, Limited: res.Limited,

View File

@ -525,11 +525,11 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent events map[string]*types.HeaderedEvent
} }
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return spec.SenderID(userID.String()), nil return spec.SenderID(userID.String()), nil
} }

View File

@ -377,7 +377,11 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
return sp, fmt.Errorf("unexpected nil state_key") return sp, fmt.Errorf("unexpected nil state_key")
} }
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return sp, err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil || userID == nil { if err != nil || userID == nil {
return sp, fmt.Errorf("failed getting userID for sender: %w", err) return sp, fmt.Errorf("failed getting userID for sender: %w", err)
} }
@ -404,7 +408,11 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
return return
} }
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey())) validRoomID, err := spec.NewRoomID(msg.Event.RoomID())
if err != nil {
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey()))
if err != nil || userID == nil { if err != nil || userID == nil {
return return
} }
@ -454,7 +462,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos) s.inviteStream.Advance(pduPos)
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID) validRoomID, err := spec.NewRoomID(msg.RoomID)
if err != nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,
"room_id": msg.RoomID,
log.ErrorKey: err,
}).Errorf("roomID is invalid")
return
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID)
if err != nil || userID == nil { if err != nil || userID == nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": msg.EventID, "event_id": msg.EventID,

View File

@ -139,7 +139,11 @@ func ApplyHistoryVisibilityFilter(
if err != nil { if err != nil {
return nil, err return nil, err
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user) roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
if err == nil { if err == nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)

View File

@ -170,11 +170,15 @@ func TrackChangedUsers(
return nil, nil, err return nil, nil, err
} }
for roomID, state := range stateRes.Rooms { for roomID, state := range stateRes.Rooms {
validRoomID, roomErr := spec.NewRoomID(roomID)
if roomErr != nil {
continue
}
for tuple, membership := range state { for tuple, membership := range state {
if membership != spec.Join { if membership != spec.Join {
continue continue
} }
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
if queryErr != nil || user == nil { if queryErr != nil || user == nil {
continue continue
} }
@ -216,13 +220,17 @@ func TrackChangedUsers(
return nil, left, err return nil, left, err
} }
for roomID, state := range stateRes.Rooms { for roomID, state := range stateRes.Rooms {
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
continue
}
for tuple, membership := range state { for tuple, membership := range state {
if membership != spec.Join { if membership != spec.Join {
continue continue
} }
// new user who we weren't previously sharing rooms with // new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok { if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey)) user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
if err != nil || user == nil { if err != nil || user == nil {
continue continue
} }

View File

@ -64,7 +64,7 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string roomIDToJoinedMembers map[string][]string
} }
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View File

@ -101,13 +101,20 @@ func (n *Notifier) OnNewEvent(
n._removeEmptyUserStreams() n._removeEmptyUserStreams()
if ev != nil { if ev != nil {
validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: RoomID is invalid",
)
return
}
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
usersToNotify := n._joinedUsers(ev.RoomID()) usersToNotify := n._joinedUsers(ev.RoomID())
// Map this event's room_id to a list of peeking devices, and wake them up. // Map this event's room_id to a list of peeking devices, and wake them up.
peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey()))
if err != nil { if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf( log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to find the userID for this event", "Notifier.OnNewEvent: Failed to find the userID for this event",

View File

@ -109,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
type TestRoomServer struct{ api.SyncRoomserverAPI } type TestRoomServer struct{ api.SyncRoomserverAPI }
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View File

@ -200,10 +200,10 @@ func Context(
} }
} }
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
@ -211,7 +211,7 @@ func Context(
if filter.LazyLoadMembers { if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent) allEvents = append(allEvents, &requestedEvent)
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
@ -224,14 +224,14 @@ func Context(
} }
} }
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, requestedEvent) }, requestedEvent)
response := ContextRespsonse{ response := ContextRespsonse{
Event: &ev, Event: &ev,
EventsAfter: eventsAfterClient, EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient, EventsBefore: eventsBeforeClient,
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), }),
} }

View File

@ -102,14 +102,28 @@ func GetEvent(
} }
sender := spec.UserID{} sender := spec.UserID{}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("roomID is invalid"),
}
}
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
if err == nil && senderUserID != nil { if err == nil && senderUserID != nil {
sender = *senderUserID sender = *senderUserID
} }
sk := events[0].StateKey() sk := events[0].StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) evRoomID, err := spec.NewRoomID(events[0].RoomID())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("roomID is invalid"),
}
}
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View File

@ -152,7 +152,15 @@ func GetMemberships(
} }
} }
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) validRoomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
if err != nil || userID == nil { if err != nil || userID == nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
return util.JSONResponse{ return util.JSONResponse{
@ -175,7 +183,7 @@ func GetMemberships(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})}, })},
} }

View File

@ -273,7 +273,7 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})...) })...)
} }
@ -389,7 +389,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
"events_before": len(events), "events_before": len(events),
"events_after": len(filteredEvents), "events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)") }).Debug("applied history visibility (messages)")
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}), start, end, err }), start, end, err
} }

View File

@ -110,19 +110,24 @@ func Relations(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.ErrorResponse(err)
}
// Convert the events into client events, and optionally filter based on the event // Convert the events into client events, and optionally filter based on the event
// type if it was specified. // type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents { for _, event := range filteredEvents {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View File

@ -205,9 +205,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse) profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) { for _, ev := range append(eventsBefore, eventsAfter...) {
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) validRoomID, roomErr := spec.NewRoomID(ev.RoomID())
if err != nil {
logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile")
continue
}
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
if queryErr != nil { if queryErr != nil {
logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile")
continue continue
} }
@ -231,14 +236,19 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
} }
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) validRoomID, roomErr := spec.NewRoomID(event.RoomID())
if err != nil {
logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile")
continue
}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -248,10 +258,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
Context: SearchContextResponse{ Context: SearchContextResponse{
Start: startToken.String(), Start: startToken.String(),
End: endToken.String(), End: endToken.String(),
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}), }),
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}), }),
ProfileInfo: profileInfos, ProfileInfo: profileInfos,
@ -272,7 +282,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
}) })
} }

View File

@ -25,7 +25,7 @@ import (
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View File

@ -114,7 +114,14 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev
}).WithError(err).Warnf("Failed to add transaction ID to event") }).WithError(err).Warnf("Failed to add transaction ID to event")
continue continue
} }
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) roomID, err := spec.NewRoomID(in[i].RoomID())
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Room ID is invalid")
continue
}
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(), "event_id": out[i].EventID(),
@ -515,7 +522,11 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI
if err != nil { if err != nil {
return "", "" return "", ""
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser) roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return "", ""
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser)
if err != nil { if err != nil {
return "", "" return "", ""
} }

View File

@ -65,14 +65,18 @@ func (p *InviteStreamProvider) IncrementalSync(
for roomID, inviteEvent := range invites { for roomID, inviteEvent := range invites {
user := spec.UserID{} user := spec.UserID{}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
if err != nil {
continue
}
sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID())
if err == nil && sender != nil { if err == nil && sender != nil {
user = *sender user = *sender
} }
sk := inviteEvent.StateKey() sk := inviteEvent.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View File

@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
} }
} }
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Join[delta.RoomID] = jr req.Response.Rooms.Join[delta.RoomID] = jr
@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms // TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Peek[delta.RoomID] = jr req.Response.Rooms.Peek[delta.RoomID] = jr
@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban: case spec.Ban:
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
req.Response.Rooms.Leave[delta.RoomID] = lr req.Response.Rooms.Leave[delta.RoomID] = lr
@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
// If we are limited by the filter AND the history visibility filter // If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited. // didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
return jr, nil return jr, nil
@ -585,6 +585,10 @@ func (p *PDUStreamProvider) lazyLoadMembers(
if len(timelineEvents) == 0 { if len(timelineEvents) == 0 {
return stateEvents, nil return stateEvents, nil
} }
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}
// Work out which memberships to include // Work out which memberships to include
timelineUsers := make(map[string]struct{}) timelineUsers := make(map[string]struct{})
if !incremental { if !incremental {
@ -606,8 +610,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
isGappedIncremental := limited && incremental isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list // We want this users membership event, keep it in the list
userID := "" userID := ""
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey())) stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && stateKeyUserID != nil { if queryErr == nil && stateKeyUserID != nil {
userID = stateKeyUserID.String() userID = stateKeyUserID.String()
} }
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID { if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {

View File

@ -40,7 +40,7 @@ type syncRoomserverAPI struct {
rooms []*test.Room rooms []*test.Room
} }
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View File

@ -52,14 +52,18 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
continue // TODO: shouldn't happen? continue // TODO: shouldn't happen?
} }
sender := spec.UserID{} sender := spec.UserID{}
userID, err := userIDForSender(se.RoomID(), se.SenderID()) validRoomID, err := spec.NewRoomID(se.RoomID())
if err != nil {
continue
}
userID, err := userIDForSender(*validRoomID, se.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := se.StateKey() sk := se.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -95,14 +99,18 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
// It provides default logic for event.SenderID & event.StateKey -> userID conversions. // It provides default logic for event.SenderID & event.StateKey -> userID conversions.
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := userIDQuery(event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return ClientEvent{}
}
userID, err := userIDQuery(*validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View File

@ -39,7 +39,7 @@ var (
roomIDCounter = int64(0) roomIDCounter = int64(0)
) )
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }

View File

@ -302,14 +302,18 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch { switch {
case event.Type() == spec.MRoomMember: case event.Type() == spec.MRoomMember:
sender := spec.UserID{} sender := spec.UserID{}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, roomErr := spec.NewRoomID(event.RoomID())
if roomErr != nil {
return roomErr
}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if queryErr == nil && userID != nil { if queryErr == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil { if queryErr == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -544,14 +548,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
} }
sender := spec.UserID{} sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return err
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if queryErr == nil && skUserID != nil { if queryErr == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString
@ -644,7 +652,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
user := "" user := ""
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err == nil { if err == nil {
user = sender.String() user = sender.String()
} }
@ -682,7 +694,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize, roomSize: roomSize,
} }
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}) })
if err != nil { if err != nil {
@ -790,7 +802,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
} }
default: default:
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) validRoomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
return nil, err
}
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil { if err != nil {
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
return nil, err return nil, err
@ -818,7 +834,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart) logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
return nil, err return nil, err
} }
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID) roomID, err := spec.NewRoomID(event.RoomID())
if err != nil {
logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID())
return nil, err
}
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
if err != nil { if err != nil {
logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID()) logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
return nil, err return nil, err

View File

@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) {
}{ }{
{ {
name: "m.receipt doesn't notify", name: "m.receipt doesn't notify",
eventContent: `{"type":"m.receipt"}`, eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`,
wantAction: pushrules.UnknownAction, wantAction: pushrules.UnknownAction,
wantActions: nil, wantActions: nil,
}, },
{ {
name: "m.reaction doesn't notify", name: "m.reaction doesn't notify",
eventContent: `{"type":"m.reaction"}`, eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`,
wantAction: pushrules.DontNotifyAction, wantAction: pushrules.DontNotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{
{ {
@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) {
}, },
{ {
name: "m.room.message notifies", name: "m.room.message notifies",
eventContent: `{"type":"m.room.message"}`, eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`,
wantNotify: true, wantNotify: true,
wantAction: pushrules.NotifyAction, wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{
@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
}, },
{ {
name: "m.room.message highlights", name: "m.room.message highlights",
eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`,
wantNotify: true, wantNotify: true,
wantAction: pushrules.NotifyAction, wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{