diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 00e78a1c1..a642f3a4b 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -170,7 +170,7 @@ func (s *FederationInternalAPI) handleDatabaseKeys( // in that case. If the key isn't valid right now, then by // leaving it in the 'requests' map, we'll try to update the // key using the fetchers in handleFetcherKeys. - if res.WasValidAt(now, true) { + if res.WasValidAt(now, gomatrixserverlib.StrictValiditySignatureCheck) { delete(requests, req) } } diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 9e1595053..552c4eac2 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -95,7 +95,7 @@ func Backfill( } } - // Query the roomserver. + // Query the Roomserver. if err = rsAPI.PerformBackfill(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("query.PerformBackfill failed") return util.JSONResponse{ diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index c6f96375e..2980c2af2 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -15,7 +15,6 @@ package routing import ( - "context" "fmt" "net/http" "sort" @@ -33,53 +32,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) -type JoinRoomQuerier struct { - roomserver api.FederationRoomserverAPI -} - -func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { - return rq.roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) -} - -func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { - return rq.roomserver.InvitePending(ctx, roomID, userID) -} - -func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { - roomInfo, err := rq.roomserver.QueryRoomInfo(ctx, roomID) - if err != nil || roomInfo == nil || roomInfo.IsStub() { - return nil, err - } - - req := api.QueryServerJoinedToRoomRequest{ - ServerName: localServerName, - RoomID: roomID.String(), - } - res := api.QueryServerJoinedToRoomResponse{} - if err = rq.roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { - util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") - return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) - } - - userJoinedToRoom, err := rq.roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") - return nil, fmt.Errorf("InternalServerError: %w", err) - } - - locallyJoinedUsers, err := rq.roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") - return nil, fmt.Errorf("InternalServerError: %w", err) - } - - return &gomatrixserverlib.RestrictedRoomJoinInfo{ - LocalServerInRoom: res.RoomExists && res.IsInRoom, - UserJoinedToRoom: userJoinedToRoom, - JoinedUsers: locallyJoinedUsers, - }, nil -} - // MakeJoin implements the /make_join API func MakeJoin( httpReq *http.Request, @@ -142,8 +94,8 @@ func MakeJoin( return event, stateEvents, nil } - roomQuerier := JoinRoomQuerier{ - roomserver: rsAPI, + roomQuerier := api.JoinRoomQuerier{ + Roomserver: rsAPI, } input := gomatrixserverlib.HandleMakeJoinInput{ diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a767168d8..d7d5b599d 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -291,10 +291,10 @@ func SendLeave( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, + ServerName: serverName, + Message: redacted, + AtTS: event.OriginServerTS(), + ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck, }} verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) if err != nil { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 3c8e0cbef..966694541 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -34,7 +34,7 @@ import ( ) const ( - // Event was passed to the roomserver + // Event was passed to the Roomserver MetricsOutcomeOK = "ok" // Event failed to be processed MetricsOutcomeFail = "fail" diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index beeb52495..76a2f3d5a 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -223,7 +223,7 @@ func ExchangeThirdPartyInvite( } } - // Send the event to the roomserver + // Send the event to the Roomserver if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, @@ -324,7 +324,7 @@ func buildMembershipEvent( return nil, errors.New("expecting state tuples for event builder, got none") } - // Ask the roomserver for information about this room + // Ask the Roomserver for information about this room queryReq := api.QueryLatestEventsAndStateRequest{ RoomID: protoEvent.RoomID, StateToFetch: eventsNeeded.Tuples(), diff --git a/go.mod b/go.mod index a20757bbc..a49dfa0c9 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index a1946adaa..79154624a 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 h1:Kh1TNvJDhWN5CdgtICNUC4G0wV2km51LGr46Dvl153A= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e h1:I3Sfr8gZvVtLHOeI8lgc62kgLuzpMhBZ6EQOMyexXEA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 37e331ab0..7400b868c 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -28,7 +28,7 @@ func (c Caches) GetServerKey( ) (gomatrixserverlib.PublicKeyLookupResult, bool) { key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) val, found := c.ServerKeys.Get(key) - if found && !val.WasValidAt(timestamp, true) { + if found && !val.WasValidAt(timestamp, gomatrixserverlib.StrictValiditySignatureCheck) { // The key wasn't valid at the requested timestamp so don't // return it. The caller will have to work out what to do. c.ServerKeys.Unset(key) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 7cb3379e0..a37ade3a3 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -32,6 +32,16 @@ func (e ErrNotAllowed) Error() string { return e.Err.Error() } +type RestrictedJoinAPI interface { + CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) + InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) + RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) + QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error) + QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error + UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error) + LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error) +} + // RoomserverInputAPI is used to write events to the room server. type RoomserverInternalAPI interface { SyncRoomserverAPI @@ -199,6 +209,7 @@ type UserRoomserverAPI interface { } type FederationRoomserverAPI interface { + RestrictedJoinAPI InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI @@ -223,7 +234,7 @@ type FederationRoomserverAPI interface { // Query whether a server is allowed to see an event QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error + QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b33698c82..e741c1402 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/types" @@ -351,26 +352,6 @@ type QueryServerBannedFromRoomResponse struct { Banned bool `json:"banned"` } -type QueryRestrictedJoinAllowedRequest struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` -} - -type QueryRestrictedJoinAllowedResponse struct { - // True if the room membership is restricted by the join rule being set to "restricted" - Restricted bool `json:"restricted"` - // True if our local server is joined to all of the allowed rooms specified in the "allow" - // key of the join rule, false if we are missing from some of them and therefore can't - // reliably decide whether or not we can satisfy the join - Resident bool `json:"resident"` - // True if the restricted join is allowed because we found the membership in one of the - // allowed rooms from the join rule, false if not - Allowed bool `json:"allowed"` - // Contains the user ID of the selected user ID that has power to issue invites, this will - // get populated into the "join_authorised_via_users_server" content in the membership - AuthorisedVia string `json:"authorised_via,omitempty"` -} - // MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { se := make(map[string]string) @@ -459,6 +440,53 @@ type QueryLeftUsersResponse struct { LeftUsers []string `json:"user_ids"` } +type JoinRoomQuerier struct { + Roomserver RestrictedJoinAPI +} + +func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { + return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey) +} + +func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) { + return rq.Roomserver.InvitePending(ctx, roomID, userID) +} + +func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := QueryServerJoinedToRoomResponse{} + if err = rq.Roomserver.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := rq.Roomserver.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil +} + type MembershipQuerier struct { Roomserver FederationRoomserverAPI } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index ee433f0d2..35b7383a9 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -94,6 +94,7 @@ func NewRoomserverAPI( Cache: caches, IsLocalServerName: dendriteCfg.Global.IsLocalServerName, ServerACLs: serverACLs, + Cfg: dendriteCfg, }, enableMetrics: enableMetrics, // perform-er structs get initialised when we have a federation sender to use diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 34bea5b6d..181a93490 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -372,22 +372,14 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( ctx context.Context, joinReq *rsAPI.PerformJoinRequest, ) (string, error) { - req := &api.QueryRestrictedJoinAllowedRequest{ - UserID: joinReq.UserID, - RoomID: joinReq.RoomIDOrAlias, + roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias) + if err != nil { + return "", err } - res := &api.QueryRestrictedJoinAllowedResponse{} - if err := r.Queryer.QueryRestrictedJoinAllowed(ctx, req, res); err != nil { - return "", fmt.Errorf("r.Queryer.QueryRestrictedJoinAllowed: %w", err) + userID, err := spec.NewUserID(joinReq.UserID, true) + if err != nil { + return "", err } - if !res.Restricted { - return "", nil - } - if !res.Resident { - return "", nil - } - if !res.Allowed { - return "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("the join to room %s was not allowed", joinReq.RoomIDOrAlias)} - } - return res.AuthorisedVia, nil + + return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index effcc90d7..6d898e8ad 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -17,10 +17,11 @@ package query import ( "context" "database/sql" - "encoding/json" "errors" "fmt" + //"github.com/matrix-org/dendrite/roomserver/internal" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" @@ -44,6 +45,42 @@ type Queryer struct { Cache caching.RoomServerCaches IsLocalServerName func(spec.ServerName) bool ServerACLs *acls.ServerACLs + Cfg *config.Dendrite +} + +func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) { + roomInfo, err := r.QueryRoomInfo(ctx, roomID) + if err != nil || roomInfo == nil || roomInfo.IsStub() { + return nil, err + } + + req := api.QueryServerJoinedToRoomRequest{ + ServerName: localServerName, + RoomID: roomID.String(), + } + res := api.QueryServerJoinedToRoomResponse{} + if err = r.QueryServerJoinedToRoom(ctx, &req, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err) + } + + userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + locallyJoinedUsers, err := r.LocallyJoinedUsers(ctx, roomInfo.RoomVersion, types.RoomNID(roomInfo.RoomNID)) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("rsAPI.GetLocallyJoinedUsers failed") + return nil, fmt.Errorf("InternalServerError: %w", err) + } + + return &gomatrixserverlib.RestrictedRoomJoinInfo{ + LocalServerInRoom: res.RoomExists && res.IsInRoom, + UserJoinedToRoom: userJoinedToRoom, + JoinedUsers: locallyJoinedUsers, + }, nil } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -906,131 +943,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse } // nolint:gocyclo -func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse) error { +func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { // Look up if we know anything about the room. If it doesn't exist // or is a stub entry then we can't do anything. - roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID.String()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) } if roomInfo == nil || roomInfo.IsStub() { - return nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) + return "", nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID) } verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion) if err != nil { - return err + return "", err } - // If the room version doesn't allow restricted joins then don't - // try to process any further. - allowRestrictedJoins := verImpl.MayAllowRestrictedJoinsInEventAuth() - if !allowRestrictedJoins { - return nil - } - // Start off by populating the "resident" flag in the response. If we - // come across any rooms in the request that are missing, we will unset - // the flag. - res.Resident = true - // Get the join rules to work out if the join rule is "restricted". - joinRulesEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomJoinRules, "") - if err != nil { - return fmt.Errorf("r.DB.GetStateEvent: %w", err) - } - if joinRulesEvent == nil { - return nil - } - var joinRules gomatrixserverlib.JoinRuleContent - if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - // If the join rule isn't "restricted" or "knock_restricted" then there's nothing more to do. - res.Restricted = joinRules.JoinRule == spec.Restricted || joinRules.JoinRule == spec.KnockRestricted - if !res.Restricted { - return nil - } - // If the user is already invited to the room then the join is allowed - // but we don't specify an authorised via user, since the event auth - // will allow the join anyway. - var pending bool - if pending, _, _, _, err = helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID); err != nil { - return fmt.Errorf("helpers.IsInvitePending: %w", err) - } else if pending { - res.Allowed = true - return nil - } - // We need to get the power levels content so that we can determine which - // users in the room are entitled to issue invites. We need to use one of - // these users as the authorising user. - powerLevelsEvent, err := r.DB.GetStateEvent(ctx, req.RoomID, spec.MRoomPowerLevels, "") - if err != nil { - return fmt.Errorf("r.DB.GetStateEvent: %w", err) - } - powerLevels, err := powerLevelsEvent.PowerLevels() - if err != nil { - return fmt.Errorf("unable to get powerlevels: %w", err) - } - // Step through the join rules and see if the user matches any of them. - for _, rule := range joinRules.Allow { - // We only understand "m.room_membership" rules at this point in - // time, so skip any rule that doesn't match those. - if rule.Type != spec.MRoomMembership { - continue - } - // See if the room exists. If it doesn't exist or if it's a stub - // room entry then we can't check memberships. - targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID) - if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub() { - res.Resident = false - continue - } - // First of all work out if *we* are still in the room, otherwise - // it's possible that the memberships will be out of date. - isIn, err := r.DB.GetLocalServerInRoom(ctx, targetRoomInfo.RoomNID) - if err != nil || !isIn { - // If we aren't in the room, we can no longer tell if the room - // memberships are up-to-date. - res.Resident = false - continue - } - // At this point we're happy that we are in the room, so now let's - // see if the target user is in the room. - _, isIn, _, err = r.DB.GetMembership(ctx, targetRoomInfo.RoomNID, req.UserID) - if err != nil { - continue - } - // If the user is not in the room then we will skip them. - if !isIn { - continue - } - // The user is in the room, so now we will need to authorise the - // join using the user ID of one of our own users in the room. Pick - // one. - joinNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, targetRoomInfo.RoomNID, true, true) - if err != nil || len(joinNIDs) == 0 { - // There should always be more than one join NID at this point - // because we are gated behind GetLocalServerInRoom, but y'know, - // sometimes strange things happen. - continue - } - // For each of the joined users, let's see if we can get a valid - // membership event. - for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID}) - if err != nil || len(events) != 1 { - continue - } - event := events[0] - if event.Type() != spec.MRoomMember || event.StateKey() == nil { - continue // shouldn't happen - } - // Only users that have the power to invite should be chosen. - if powerLevels.UserLevel(*event.StateKey()) < powerLevels.Invite { - continue - } - res.Resident = true - res.Allowed = true - res.AuthorisedVia = *event.StateKey() - return nil - } - } - return nil + + return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index d19ebebe4..11a0f5817 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -598,16 +598,15 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { testCases := []struct { name string prepareRoomFunc func(t *testing.T) *test.Room - wantResponse api.QueryRestrictedJoinAllowedResponse + wantResponse string + wantError bool }{ { name: "public room unrestricted", prepareRoomFunc: func(t *testing.T) *test.Room { return test.NewRoom(t, alice) }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - }, + wantResponse: "", }, { name: "room version without restrictions", @@ -624,10 +623,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey("")) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - }, + wantError: true, }, { name: "knock_restricted", @@ -638,10 +634,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey("")) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - }, + wantError: true, }, { name: "restricted with pending invite", // bob should be allowed to join @@ -655,11 +648,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey(bob.ID)) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - Allowed: true, - }, + wantResponse: "", }, { name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room @@ -680,9 +669,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey(bob.ID)) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Restricted: true, - }, + wantError: true, }, { name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room @@ -703,12 +690,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { }, test.WithStateKey(bob.ID)) return r }, - wantResponse: api.QueryRestrictedJoinAllowedResponse{ - Resident: true, - Restricted: true, - Allowed: true, - AuthorisedVia: alice.ID, - }, + wantResponse: alice.ID, }, } @@ -738,16 +720,17 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) { t.Errorf("failed to send events: %v", err) } - req := api.QueryRestrictedJoinAllowedRequest{ - UserID: bob.ID, - RoomID: testRoom.ID, + roomID, _ := spec.NewRoomID(testRoom.ID) + userID, _ := spec.NewUserID(bob.ID, true) + got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID) + if tc.wantError && err == nil { + t.Fatal("expected error, got none") } - res := api.QueryRestrictedJoinAllowedResponse{} - if err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), &req, &res); err != nil { + if !tc.wantError && err != nil { t.Fatal(err) } - if !reflect.DeepEqual(tc.wantResponse, res) { - t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, res) + if !reflect.DeepEqual(tc.wantResponse, got) { + t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, got) } }) }