diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 890b18183..700a72f5d 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -74,7 +74,7 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.BadJSON("A password must be supplied."), } } - localpart, err := userutil.ParseUsernameParam(username, &t.Config.Matrix.ServerName) + localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 89c269f1a..69bca13be 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -70,7 +70,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.MissingArgument("User ID must belong to this server."), @@ -169,7 +169,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien if err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if domain == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidParam("Can not mark local device list as stale"), diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 3e837c864..eefe8e24b 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -169,9 +169,21 @@ func createRoom( asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, ) util.JSONResponse { + _, userDomain, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + if !cfg.Matrix.IsLocalServerName(userDomain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(fmt.Sprintf("User domain %q not configured locally", userDomain)), + } + } + // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) logger := util.GetLogger(ctx) userID := device.UserID @@ -314,7 +326,7 @@ func createRoom( var roomAlias string if r.RoomAliasName != "" { - roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, cfg.Matrix.ServerName) + roomAlias = fmt.Sprintf("#%s:%s", r.RoomAliasName, userDomain) // check it's free TODO: This races but is better than nothing hasAliasReq := roomserverAPI.GetRoomIDForAliasRequest{ Alias: roomAlias, @@ -436,7 +448,7 @@ func createRoom( builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var ev *gomatrixserverlib.Event - ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) + ev, err = buildEvent(&builder, userDomain, &authEvents, cfg, evTime, roomVersion) if err != nil { util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() @@ -461,7 +473,7 @@ func createRoom( inputs = append(inputs, roomserverAPI.InputRoomEvent{ Kind: roomserverAPI.KindNew, Event: event, - Origin: cfg.Matrix.ServerName, + Origin: userDomain, SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } @@ -548,7 +560,7 @@ func createRoom( Event: event, InviteRoomState: inviteStrippedState, RoomVersion: event.RoomVersion, - SendAsServer: string(cfg.Matrix.ServerName), + SendAsServer: string(userDomain), }, &inviteRes); err != nil { util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ @@ -591,6 +603,7 @@ func createRoom( // buildEvent fills out auth_events for the builder then builds the event func buildEvent( builder *gomatrixserverlib.EventBuilder, + serverName gomatrixserverlib.ServerName, provider gomatrixserverlib.AuthEventProvider, cfg *config.ClientAPI, evTime time.Time, @@ -606,7 +619,7 @@ func buildEvent( } builder.AuthEvents = refs event, err := builder.Build( - evTime, cfg.Matrix.ServerName, cfg.Matrix.KeyID, + evTime, serverName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, roomVersion, ) if err != nil { diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 836d9e152..33bc63d18 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -75,7 +75,7 @@ func DirectoryRoom( if res.RoomID == "" { // If we don't know it locally, do a federation query. // But don't send the query to ourselves. - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. @@ -127,7 +127,7 @@ func SetLocalAlias( } } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Alias must be on local homeserver"), diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 8ddb3267a..4ebf2295a 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -62,8 +62,7 @@ func GetPostPublicRooms( } serverName := gomatrixserverlib.ServerName(request.Server) - - if serverName != "" && serverName != cfg.Matrix.ServerName { + if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( req.Context(), serverName, int(request.Limit), request.Since, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 6017b5840..7f5a8c4f8 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -68,7 +68,7 @@ func Login( return *authErr } // make a device/access token - authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authErr2 := completeAuth(req.Context(), cfg.Matrix, userAPI, login, req.RemoteAddr, req.UserAgent()) cleanup(req.Context(), &authErr2) return authErr2 } @@ -79,7 +79,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.ClientUserAPI, login *auth.Login, + ctx context.Context, cfg *config.Global, userAPI userapi.ClientUserAPI, login *auth.Login, ipAddr, userAgent string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() @@ -88,7 +88,7 @@ func completeAuth( return jsonerror.InternalServerError() } - localpart, err := userutil.ParseUsernameParam(login.Username(), &serverName) + localpart, serverName, err := userutil.ParseUsernameParam(login.Username(), cfg) if err != nil { util.GetLogger(ctx).WithError(err).Error("auth.ParseUsernameParam failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 77f627eb2..94ba17a02 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -105,12 +105,13 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic return jsonerror.InternalServerError() } + serverName := device.UserDomain() if err = roomserverAPI.SendEvents( ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, - cfg.Matrix.ServerName, - cfg.Matrix.ServerName, + serverName, + serverName, nil, false, ); err != nil { @@ -271,7 +272,7 @@ func sendInvite( Event: event, InviteRoomState: nil, // ask the roomserver to draw up invite room state for us RoomVersion: event.RoomVersion, - SendAsServer: string(cfg.Matrix.ServerName), + SendAsServer: string(device.UserDomain()), }, &inviteRes); err != nil { util.GetLogger(ctx).WithError(err).Error("PerformInvite failed") return util.JSONResponse{ @@ -341,7 +342,7 @@ func loadProfile( } var profile *authtypes.Profile - if serverName == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(serverName) { profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, profileAPI) } else { profile = &authtypes.Profile{} diff --git a/clientapi/routing/openid.go b/clientapi/routing/openid.go index cfb440bea..8e9be7889 100644 --- a/clientapi/routing/openid.go +++ b/clientapi/routing/openid.go @@ -63,7 +63,7 @@ func CreateOpenIDToken( JSON: openIDTokenResponse{ AccessToken: response.Token.Token, TokenType: "Bearer", - MatrixServerName: string(cfg.Matrix.ServerName), + MatrixServerName: string(device.UserDomain()), ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s }, } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index c9647eb1b..4d9e1f8a5 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -113,12 +113,19 @@ func SetAvatarURL( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } + if !cfg.Matrix.IsLocalServerName(domain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -129,8 +136,9 @@ func SetAvatarURL( setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ - Localpart: localpart, - AvatarURL: r.AvatarURL, + Localpart: localpart, + ServerName: domain, + AvatarURL: r.AvatarURL, }, setRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() @@ -204,12 +212,19 @@ func SetDisplayName( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } + if !cfg.Matrix.IsLocalServerName(domain) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not belong to a locally configured domain"), + } + } + evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -221,6 +236,7 @@ func SetDisplayName( profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, + ServerName: domain, DisplayName: r.DisplayName, }, profileRes) if err != nil { @@ -261,6 +277,12 @@ func updateProfile( return jsonerror.InternalServerError(), err } + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError(), err + } + events, err := buildMembershipEvents( ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) @@ -276,7 +298,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -298,7 +320,7 @@ func getProfile( return nil, err } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index a0f3b1152..778a02fd4 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -131,7 +131,8 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { + domain := device.UserDomain() + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 0bda1e488..698d185b4 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -412,7 +412,7 @@ func UserIDIsWithinApplicationServiceNamespace( return false } - if domain != cfg.Matrix.ServerName { + if !cfg.Matrix.IsLocalServerName(domain) { return false } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 114e9088d..bb66cf6fc 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -94,6 +94,7 @@ func SendEvent( // create a mutex for the specific user in the specific room // this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order userID := device.UserID + domain := device.UserDomain() mutex, _ := userRoomSendMutexes.LoadOrStore(roomID+userID, &sync.Mutex{}) mutex.(*sync.Mutex).Lock() defer mutex.(*sync.Mutex).Unlock() @@ -185,8 +186,8 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, - cfg.Matrix.ServerName, - cfg.Matrix.ServerName, + domain, + domain, txnAndSessionID, false, ); err != nil { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 9670fecad..99fb8171d 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -215,7 +215,7 @@ func queryIDServerStoreInvite( } var profile *authtypes.Profile - if serverName == cfg.Matrix.ServerName { + if cfg.Matrix.IsLocalServerName(serverName) { res := &userapi.QueryProfileResponse{} err = userAPI.QueryProfile(ctx, &userapi.QueryProfileRequest{UserID: device.UserID}, res) if err != nil { diff --git a/clientapi/userutil/userutil.go b/clientapi/userutil/userutil.go index 7e909ffad..9be1e9b31 100644 --- a/clientapi/userutil/userutil.go +++ b/clientapi/userutil/userutil.go @@ -17,6 +17,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -24,23 +25,23 @@ import ( // usernameParam can either be a user ID or just the localpart/username. // If serverName is passed, it is verified against the domain obtained from usernameParam (if present) // Returns error in case of invalid usernameParam. -func ParseUsernameParam(usernameParam string, expectedServerName *gomatrixserverlib.ServerName) (string, error) { +func ParseUsernameParam(usernameParam string, cfg *config.Global) (string, gomatrixserverlib.ServerName, error) { localpart := usernameParam if strings.HasPrefix(usernameParam, "@") { lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam) if err != nil { - return "", errors.New("invalid username") + return "", "", errors.New("invalid username") } - if expectedServerName != nil && domain != *expectedServerName { - return "", errors.New("user ID does not belong to this server") + if !cfg.IsLocalServerName(domain) { + return "", "", errors.New("user ID does not belong to this server") } - localpart = lp + return lp, domain, nil } - return localpart, nil + return localpart, cfg.ServerName, nil } // MakeUserID generates user ID from localpart & server name diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index 2628642fb..ccd6647b2 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -15,6 +15,7 @@ package userutil import ( "testing" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -28,7 +29,11 @@ var ( // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { - lp, err := ParseUsernameParam(goodUserID, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + lp, _, err := ParseUsernameParam(goodUserID, cfg) if err != nil { t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error()) @@ -41,7 +46,11 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { - lp, err := ParseUsernameParam(localpart, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + lp, _, err := ParseUsernameParam(localpart, cfg) if err != nil { t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error()) @@ -54,7 +63,11 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { - _, err := ParseUsernameParam(goodUserID, &invalidServerName) + cfg := &config.Global{ + ServerName: invalidServerName, + } + + _, _, err := ParseUsernameParam(goodUserID, cfg) if err == nil { t.Error("Invalid Domain should return an error") @@ -63,7 +76,11 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { - _, err := ParseUsernameParam(badUserID, &serverName) + cfg := &config.Global{ + ServerName: serverName, + } + + _, _, err := ParseUsernameParam(badUserID, cfg) if err == nil { t.Error("Illegal User ID should return an error") diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index f6dace702..a58cba1b1 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -69,7 +69,7 @@ func AddPublicRoutes( TopicPresenceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), TopicDeviceListUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), TopicSigningKeyUpdate: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - ServerName: cfg.Matrix.ServerName, + Config: cfg, UserAPI: userAPI, } @@ -107,7 +107,7 @@ func NewInternalAPI( ) api.FederationInternalAPI { cfg := &base.Cfg.FederationAPI - federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.ServerName) + federationDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 85cc43aa5..7ccc02f76 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -87,6 +87,7 @@ func TestMain(m *testing.M) { cfg.Global.JetStream.StoragePath = config.Path(d) cfg.Global.KeyID = serverKeyID cfg.Global.KeyValidityPeriod = s.validity + cfg.FederationAPI.KeyPerspectives = nil f, err := os.CreateTemp(d, "federation_keys_test*.db") if err != nil { return -1 @@ -207,7 +208,6 @@ func TestRenewalBehaviour(t *testing.T) { // happy at this point that the key that we already have is from the past // then repeating a key fetch should cause us to try and renew the key. // If so, then the new key will end up in our cache. - serverC.renew() res, err = serverA.api.FetchKeys( diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index e923143a7..c37bc87c2 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -164,6 +164,7 @@ func TestFederationAPIJoinThenKeyUpdate(t *testing.T) { func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { base, close := testrig.CreateBaseDendrite(t, dbType) base.Cfg.FederationAPI.PreferDirectFetch = true + base.Cfg.FederationAPI.KeyPerspectives = nil defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) diff --git a/federationapi/internal/keys.go b/federationapi/internal/keys.go index 2b7a8219a..258bd88bf 100644 --- a/federationapi/internal/keys.go +++ b/federationapi/internal/keys.go @@ -99,7 +99,7 @@ func (s *FederationInternalAPI) handleLocalKeys( results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) { for req := range requests { - if req.ServerName != s.cfg.Matrix.ServerName { + if !s.cfg.Matrix.IsLocalServerName(req.ServerName) { continue } if req.KeyID == s.cfg.Matrix.KeyID { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 28ec48d7b..1b61ec711 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -77,7 +77,7 @@ func (r *FederationInternalAPI) PerformJoin( seenSet := make(map[gomatrixserverlib.ServerName]bool) var uniqueList []gomatrixserverlib.ServerName for _, srv := range request.ServerNames { - if seenSet[srv] || srv == r.cfg.Matrix.ServerName { + if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) { continue } seenSet[srv] = true diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 659ff1bcf..7cce13a7d 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -25,6 +25,7 @@ import ( "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -39,7 +40,7 @@ type SyncAPIProducer struct { TopicDeviceListUpdate string TopicSigningKeyUpdate string JetStream nats.JetStreamContext - ServerName gomatrixserverlib.ServerName + Config *config.FederationAPI UserAPI userapi.UserInternalAPI } @@ -77,7 +78,7 @@ func (p *SyncAPIProducer) SendToDevice( // device. If the event isn't targeted locally then we can't expand the // wildcard as we don't know about the remote devices, so instead we leave it // as-is, so that the federation sender can send it on with the wildcard intact. - if domain == p.ServerName && deviceID == "*" { + if p.Config.Matrix.IsLocalServerName(domain) && deviceID == "*" { var res userapi.QueryDevicesResponse err = p.UserAPI.QueryDevices(context.TODO(), &userapi.QueryDevicesRequest{ UserID: userID, diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index a1b280103..7ef4646f7 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -47,7 +47,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase connStr, dbClose := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.ServerName) + }, b.Caches, b.Cfg.Global.IsLocalServerName) if err != nil { t.Fatalf("NewDatabase returned %s", err) } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index e25f9866e..9f16e5093 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -124,7 +124,7 @@ func Setup( mu := internal.NewMutexByRoom() v1fedmux.Handle("/send/{txnID}", MakeFedAPI( - "federation_send", cfg.Matrix.ServerName, keys, wakeup, + "federation_send", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), @@ -134,7 +134,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -150,7 +150,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v2fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( - "federation_invite", cfg.Matrix.ServerName, keys, wakeup, + "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -172,7 +172,7 @@ func Setup( )).Methods(http.MethodPost, http.MethodOptions) v1fedmux.Handle("/exchange_third_party_invite/{roomID}", MakeFedAPI( - "exchange_third_party_invite", cfg.Matrix.ServerName, keys, wakeup, + "exchange_third_party_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ExchangeThirdPartyInvite( httpReq, request, vars["roomID"], rsAPI, cfg, federation, @@ -181,7 +181,7 @@ func Setup( )).Methods(http.MethodPut, http.MethodOptions) v1fedmux.Handle("/event/{eventID}", MakeFedAPI( - "federation_get_event", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_event", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetEvent( httpReq.Context(), request, rsAPI, vars["eventID"], cfg.Matrix.ServerName, @@ -190,7 +190,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state/{roomID}", MakeFedAPI( - "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_state", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -205,7 +205,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/state_ids/{roomID}", MakeFedAPI( - "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_state_ids", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -220,7 +220,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/event_auth/{roomID}/{eventID}", MakeFedAPI( - "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_event_auth", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -235,7 +235,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/directory", MakeFedAPI( - "federation_query_room_alias", cfg.Matrix.ServerName, keys, wakeup, + "federation_query_room_alias", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return RoomAliasToID( httpReq, federation, cfg, rsAPI, fsAPI, @@ -244,7 +244,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/query/profile", MakeFedAPI( - "federation_query_profile", cfg.Matrix.ServerName, keys, wakeup, + "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetProfile( httpReq, userAPI, cfg, @@ -253,7 +253,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( - "federation_user_devices", cfg.Matrix.ServerName, keys, wakeup, + "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( httpReq, keyAPI, vars["userID"], @@ -263,7 +263,7 @@ func Setup( if mscCfg.Enabled("msc2444") { v1fedmux.Handle("/peek/{roomID}/{peekID}", MakeFedAPI( - "federation_peek", cfg.Matrix.ServerName, keys, wakeup, + "federation_peek", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -294,7 +294,7 @@ func Setup( } v1fedmux.Handle("/make_join/{roomID}/{userID}", MakeFedAPI( - "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_make_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -325,7 +325,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -357,7 +357,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_join/{roomID}/{eventID}", MakeFedAPI( - "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_join", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -374,7 +374,7 @@ func Setup( )).Methods(http.MethodPut) v1fedmux.Handle("/make_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_make_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -391,7 +391,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -423,7 +423,7 @@ func Setup( )).Methods(http.MethodPut) v2fedmux.Handle("/send_leave/{roomID}/{eventID}", MakeFedAPI( - "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, + "federation_send_leave", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -447,7 +447,7 @@ func Setup( )).Methods(http.MethodGet) v1fedmux.Handle("/get_missing_events/{roomID}", MakeFedAPI( - "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, + "federation_get_missing_events", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -460,7 +460,7 @@ func Setup( )).Methods(http.MethodPost) v1fedmux.Handle("/backfill/{roomID}", MakeFedAPI( - "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, + "federation_backfill", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ @@ -479,14 +479,14 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost) v1fedmux.Handle("/user/keys/claim", MakeFedAPI( - "federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, + "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( - "federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, + "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) }, @@ -525,15 +525,15 @@ func ErrorIfLocalServerNotInRoom( // MakeFedAPI makes an http.Handler that checks matrix federation authentication. func MakeFedAPI( - metricsName string, - serverName gomatrixserverlib.ServerName, + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, keyRing gomatrixserverlib.JSONVerifier, wakeup *FederationWakeups, f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), serverName, keyRing, + req, time.Now(), serverName, isLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 6e208d096..a33fa4a43 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -36,7 +36,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { @@ -96,7 +96,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } d.Database = shared.Database{ DB: d.db, - ServerName: serverName, + IsLocalServerName: isLocalServerName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6afb313a8..4fabff7d4 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -29,7 +29,7 @@ import ( type Database struct { DB *sql.DB - ServerName gomatrixserverlib.ServerName + IsLocalServerName func(gomatrixserverlib.ServerName) bool Cache caching.FederationCache Writer sqlutil.Writer FederationQueuePDUs tables.FederationQueuePDUs @@ -124,7 +124,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, } if excludeSelf { for i, server := range servers { - if server == d.ServerName { + if d.IsLocalServerName(server) { servers = append(servers[:i], servers[i+1:]...) } } diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index c89cb6bea..e86ac817b 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -35,7 +35,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { @@ -95,7 +95,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, } d.Database = shared.Database{ DB: d.db, - ServerName: serverName, + IsLocalServerName: isLocalServerName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index f246b9bc9..142e281ea 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -29,12 +29,12 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { +func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, serverName) + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, serverName) + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 6272fd2b1..f7408fa9f 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -19,7 +19,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat connStr, dbClose := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, b.Caches, b.Cfg.Global.ServerName) + }, b.Caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) if err != nil { t.Fatalf("NewDatabase returned %s", err) } diff --git a/go.mod b/go.mod index 7f9bb3897..39dfb0fe1 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a + github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 5cce7e0d8..5e6253860 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a h1:6rJFN5NBuzZ7h5meYkLtXKa6VFZfDc8oVXHd4SDXr5o= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221021091412-7c772f1b388a/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa h1:S98DShDv3sn7O4n4HjtJOejypseYVpv1R/XPg+cDnfI= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221025142407-17b0be811afa/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3 h1:lzkSQvBv8TuqKJCPoVwOVvEnARTlua5rrNy/Qw2Vxeo= github.com/matrix-org/pinecone v0.0.0-20221007145426-3adc85477dd3/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index cb6b22d32..6a6d51b0a 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -117,6 +117,11 @@ func (r *Admin) PerformAdminEvacuateRoom( PrevEvents: prevEvents, } + _, senderDomain, err := gomatrixserverlib.SplitID('@', fledglingEvent.Sender) + if err != nil { + continue + } + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -146,8 +151,8 @@ func (r *Admin) PerformAdminEvacuateRoom( inputEvents = append(inputEvents, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: r.Cfg.Matrix.ServerName, - SendAsServer: string(r.Cfg.Matrix.ServerName), + Origin: senderDomain, + SendAsServer: string(senderDomain), }) res.Affected = append(res.Affected, stateKey) prevEvents = []gomatrixserverlib.EventReference{ @@ -176,7 +181,7 @@ func (r *Admin) PerformAdminEvacuateUser( } return nil } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: "Can only evacuate local users using this endpoint", diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 3fbdf332e..f60247cd7 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -70,8 +70,8 @@ func (r *Inviter) PerformInvite( } return nil, nil } - isTargetLocal := domain == r.Cfg.Matrix.ServerName - isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isOriginLocal := r.Cfg.Matrix.IsLocalServerName(senderDomain) if !isOriginLocal && !isTargetLocal { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 262273ff5..9d596ab30 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -92,7 +92,7 @@ func (r *Joiner) performJoin( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), @@ -124,7 +124,7 @@ func (r *Joiner) performJoinRoomByAlias( // Check if this alias matches our own server configuration. If it // doesn't then we'll need to try a federated join. var roomID string - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // The alias isn't owned by us, so we will need to try joining using // a remote server. dirReq := fsAPI.PerformDirectoryLookupRequest{ @@ -172,7 +172,7 @@ func (r *Joiner) performJoinRoomByID( // The original client request ?server_name=... may include this HS so filter that out so we // don't attempt to make_join with ourselves for i := 0; i < len(req.ServerNames); i++ { - if req.ServerNames[i] == r.Cfg.Matrix.ServerName { + if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) { // delete this entry req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...) i-- @@ -191,12 +191,19 @@ func (r *Joiner) performJoinRoomByID( // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { req.ServerNames = append(req.ServerNames, domain) } // Prepare the template for the join event. userID := req.UserID + _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorBadRequest, + Msg: fmt.Sprintf("User ID %q is invalid: %s", userID, err), + } + } eb := gomatrixserverlib.EventBuilder{ Type: gomatrixserverlib.MRoomMember, Sender: userID, @@ -247,7 +254,7 @@ func (r *Joiner) performJoinRoomByID( // If we were invited by someone from another server then we can // assume they are in the room so we can join via them. - if inviterDomain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) { req.ServerNames = append(req.ServerNames, inviterDomain) forceFederatedJoin = true memberEvent := gjson.Parse(string(inviteEvent.JSON())) @@ -300,7 +307,7 @@ func (r *Joiner) performJoinRoomByID( { Kind: rsAPI.KindNew, Event: event.Headered(buildRes.RoomVersion), - SendAsServer: string(r.Cfg.Matrix.ServerName), + SendAsServer: string(userDomain), }, }, } @@ -323,7 +330,7 @@ func (r *Joiner) performJoinRoomByID( // The room doesn't exist locally. If the room ID looks like it should // be ours then this probably means that we've nuked our database at // some point. - if domain == r.Cfg.Matrix.ServerName { + if r.Cfg.Matrix.IsLocalServerName(domain) { // If there are no more server names to try then give up here. // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. @@ -348,7 +355,7 @@ func (r *Joiner) performJoinRoomByID( // it will have been overwritten with a room ID by performJoinRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil + return req.RoomIDOrAlias, userDomain, nil } func (r *Joiner) performFederatedJoinRoomByID( diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 85b659814..49e4b479a 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -52,7 +52,7 @@ func (r *Leaver) PerformLeave( if err != nil { return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID) } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ @@ -85,7 +85,7 @@ func (r *Leaver) performLeaveRoomByID( if serr != nil { return nil, fmt.Errorf("sender %q is invalid", senderUser) } - if senderDomain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(senderDomain) { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } // check that this is not a "server notice room" @@ -186,7 +186,7 @@ func (r *Leaver) performLeaveRoomByID( Kind: api.KindNew, Event: event.Headered(buildRes.RoomVersion), Origin: senderDomain, - SendAsServer: string(r.Cfg.Matrix.ServerName), + SendAsServer: string(senderDomain), }, }, } diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 74d87a5b4..436d137ff 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -72,7 +72,7 @@ func (r *Peeker) performPeek( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), @@ -104,7 +104,7 @@ func (r *Peeker) performPeekRoomByAlias( // Check if this alias matches our own server configuration. If it // doesn't then we'll need to try a federated peek. var roomID string - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // The alias isn't owned by us, so we will need to try peeking using // a remote server. dirReq := fsAPI.PerformDirectoryLookupRequest{ @@ -154,7 +154,7 @@ func (r *Peeker) performPeekRoomByID( // handle federated peeks // FIXME: don't create an outbound peek if we already have one going. - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { // If the server name in the room ID isn't ours then it's a // possible candidate for finding the room via federation. Add // it to the list of servers to try. diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 49e9067c9..0d97da4d6 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -67,7 +67,7 @@ func (r *Unpeeker) performUnpeek( Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } - if domain != r.Cfg.Matrix.ServerName { + if !r.Cfg.Matrix.IsLocalServerName(domain) { return &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index d6dc9708c..38abe323c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,6 +60,13 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID + _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorNotAllowed, + Msg: "Error validating the user ID", + } + } evTime := time.Now() // Return an immediate error if the room does not exist @@ -80,7 +87,7 @@ func (r *Upgrader) performRoomUpgrade( // TODO (#267): Check room ID doesn't clash with an existing one, and we // probably shouldn't be using pseudo-random strings, maybe GUIDs? - newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), r.Cfg.Matrix.ServerName) + newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain) // Get the existing room state for the old room. oldRoomReq := &api.QueryLatestEventsAndStateRequest{ @@ -107,12 +114,12 @@ func (r *Upgrader) performRoomUpgrade( } // Send the setup events to the new room - if pErr = r.sendInitialEvents(ctx, evTime, userID, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { + if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, string(req.RoomVersion), eventsToMake); pErr != nil { return "", pErr } // 5. Send the tombstone event to the old room - if pErr = r.sendHeaderedEvent(ctx, tombstoneEvent, string(r.Cfg.Matrix.ServerName)); pErr != nil { + if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil { return "", pErr } @@ -122,7 +129,7 @@ func (r *Upgrader) performRoomUpgrade( } // If the old room had a canonical alias event, it should be deleted in the old room - if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, roomID); pErr != nil { + if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil { return "", pErr } @@ -132,7 +139,7 @@ func (r *Upgrader) performRoomUpgrade( } // 6. Restrict power levels in the old room - if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, roomID); pErr != nil { + if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil { return "", pErr } @@ -154,7 +161,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma return powerLevelContent, nil } -func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID, roomID string) *api.PerformError { +func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID) if pErr != nil { return pErr @@ -183,7 +190,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T return resErr } } else { - if resErr = r.sendHeaderedEvent(ctx, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { + if resErr = r.sendHeaderedEvent(ctx, userDomain, restrictedPowerLevelsHeadered, api.DoNotSendToOtherServers); resErr != nil { return resErr } } @@ -223,7 +230,7 @@ func moveLocalAliases(ctx context.Context, return nil } -func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID, roomID string) *api.PerformError { +func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, roomID string) *api.PerformError { for _, event := range oldRoom.StateEvents { if event.Type() != gomatrixserverlib.MRoomCanonicalAlias || !event.StateKeyEquals("") { continue @@ -254,7 +261,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api return resErr } } else { - if resErr = r.sendHeaderedEvent(ctx, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { + if resErr = r.sendHeaderedEvent(ctx, userDomain, emptyCanonicalAliasEvent, api.DoNotSendToOtherServers); resErr != nil { return resErr } } @@ -495,7 +502,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query return eventsToMake, nil } -func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { +func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain gomatrixserverlib.ServerName, newRoomID, newVersion string, eventsToMake []fledglingEvent) *api.PerformError { var err error var builtEvents []*gomatrixserverlib.HeaderedEvent authEvents := gomatrixserverlib.NewAuthEvents(nil) @@ -519,7 +526,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user builder.PrevEvents = []gomatrixserverlib.EventReference{builtEvents[i-1].EventReference()} } var event *gomatrixserverlib.Event - event, err = r.buildEvent(&builder, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) + event, err = r.buildEvent(&builder, userDomain, &authEvents, evTime, gomatrixserverlib.RoomVersion(newVersion)) if err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to build new %q event: %s", builder.Type, err), @@ -547,7 +554,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, Event: event, - Origin: r.Cfg.Matrix.ServerName, + Origin: userDomain, SendAsServer: api.DoNotSendToOtherServers, }) } @@ -668,6 +675,7 @@ func createTemporaryPowerLevels(powerLevelContent *gomatrixserverlib.PowerLevelC func (r *Upgrader) sendHeaderedEvent( ctx context.Context, + serverName gomatrixserverlib.ServerName, headeredEvent *gomatrixserverlib.HeaderedEvent, sendAsServer string, ) *api.PerformError { @@ -675,7 +683,7 @@ func (r *Upgrader) sendHeaderedEvent( inputs = append(inputs, api.InputRoomEvent{ Kind: api.KindNew, Event: headeredEvent, - Origin: r.Cfg.Matrix.ServerName, + Origin: serverName, SendAsServer: sendAsServer, }) if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { @@ -689,6 +697,7 @@ func (r *Upgrader) sendHeaderedEvent( func (r *Upgrader) buildEvent( builder *gomatrixserverlib.EventBuilder, + serverName gomatrixserverlib.ServerName, provider gomatrixserverlib.AuthEventProvider, evTime time.Time, roomVersion gomatrixserverlib.RoomVersion, @@ -703,7 +712,7 @@ func (r *Upgrader) buildEvent( } builder.AuthEvents = refs event, err := builder.Build( - evTime, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, + evTime, serverName, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey, roomVersion, ) if err != nil { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 784893d24..825772827 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -14,6 +14,9 @@ type Global struct { // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + // The secondary server names, used for virtual hosting. + SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` + // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` @@ -120,6 +123,18 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Cache.Verify(configErrs, isMonolith) } +func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool { + if c.ServerName == serverName { + return true + } + for _, secondaryName := range c.SecondaryServerNames { + if secondaryName == serverName { + return true + } + } + return false +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 452b14580..98502f5cb 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -132,7 +132,7 @@ func Enable( base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, keyRing, + req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index a92a16a27..bc9df0f96 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -64,7 +64,7 @@ func Enable( fedAPI := httputil.MakeExternalAPI( "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( - req, time.Now(), base.Cfg.Global.ServerName, keyRing, + req, time.Now(), base.Cfg.Global.ServerName, base.Cfg.Global.IsLocalServerName, keyRing, ) if fedReq == nil { return errResp diff --git a/test/testrig/base.go b/test/testrig/base.go index 10cc2407b..15fb5c370 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -36,6 +36,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f Monolithic: true, }) cfg.Global.JetStream.InMemory = true + cfg.FederationAPI.KeyPerspectives = nil switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(config.DefaultOpts{ // autogen a signing key @@ -106,6 +107,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat } cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true + cfg.FederationAPI.KeyPerspectives = nil base := base.NewBaseDendrite(cfg, "Tests") js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc diff --git a/userapi/api/api.go b/userapi/api/api.go index eef29144a..8d7f783de 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -318,8 +318,9 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account @@ -360,7 +361,8 @@ type PerformLastSeenUpdateResponse struct { // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string - AccessToken string // optional: if blank one will be made on your behalf + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used + AccessToken string // optional: if blank one will be made on your behalf // optional: if nil an ID is generated for you. If set, replaces any existing device session, // which will generate a new access token and invalidate the old one. DeviceID *string @@ -384,7 +386,8 @@ type PerformDeviceCreationResponse struct { // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used } // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation @@ -434,6 +437,18 @@ type Device struct { AccountType AccountType } +func (d *Device) UserDomain() gomatrixserverlib.ServerName { + _, domain, err := gomatrixserverlib.SplitID('@', d.UserID) + if err != nil { + // This really is catastrophic because it means that someone + // managed to forge a malformed user ID for a device during + // login. + // TODO: Is there a better way to deal with this than panic? + panic(err) + } + return domain +} + // Account represents a Matrix account on this home server. type Account struct { UserID string @@ -577,7 +592,9 @@ type Notification struct { } type PerformSetAvatarURLRequest struct { - Localpart, AvatarURL string + Localpart string + ServerName gomatrixserverlib.ServerName + AvatarURL string } type PerformSetAvatarURLResponse struct { Profile *authtypes.Profile `json:"profile"` @@ -606,7 +623,9 @@ type QueryAccountByPasswordResponse struct { } type PerformUpdateDisplayNameRequest struct { - Localpart, DisplayName string + Localpart string + ServerName gomatrixserverlib.ServerName + DisplayName string } type PerformUpdateDisplayNameResponse struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 7b94b3da7..9ca76965d 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -46,9 +46,9 @@ import ( type UserInternalAPI struct { DB storage.Database SyncProducer *producers.SyncAPI + Config *config.UserAPI DisableTLSValidation bool - ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI @@ -62,8 +62,8 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) } if req.DataType == "" { return fmt.Errorf("data type must not be empty") @@ -104,7 +104,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") return nil } - if domain != a.ServerName { + if !a.Config.Matrix.IsLocalServerName(domain) { return nil } @@ -171,6 +171,11 @@ func addUserToRoom( } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + // XXXX: Use the server name here acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists @@ -188,8 +193,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = &api.Account{ AppServiceID: req.AppServiceID, Localpart: req.Localpart, - ServerName: a.ServerName, - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + ServerName: serverName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), AccountType: req.AccountType, } return nil @@ -235,6 +240,12 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + _ = serverName + // XXXX: Use the server name here util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -259,8 +270,8 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) } deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { @@ -392,8 +403,8 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { @@ -443,8 +454,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { @@ -460,8 +471,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) } if req.DataType != "" { var data json.RawMessage @@ -509,10 +520,13 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc } return err } - localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return err } + if !a.Config.Matrix.IsLocalServerName(domain) { + return nil + } acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) if err != nil { return err @@ -547,7 +561,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) + localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } @@ -572,8 +586,16 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %q not locally configured", serverName) + } + evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { @@ -584,7 +606,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a } deviceReq := &api.PerformDeviceDeletionRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } deviceRes := &api.PerformDeviceDeletionResponse{} if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index f1bf391e4..87f25e5e2 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -31,8 +31,8 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot create a login token for a remote user (server name %s)", domain) } tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { @@ -63,8 +63,8 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil diff --git a/userapi/userapi.go b/userapi/userapi.go index c077248e2..e46a8e76e 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -76,7 +76,7 @@ func NewInternalAPI( userAPI := &internal.UserInternalAPI{ DB: db, SyncProducer: syncProducer, - ServerName: cfg.Matrix.ServerName, + Config: cfg, AppServices: appServices, KeyAPI: keyAPI, RSAPI: rsAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index aaa93f45b..2a43c0bd4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -66,8 +66,8 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } return &internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, + DB: accountDB, + Config: cfg, }, accountDB, func() { close() baseclose()