diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index 2ff4b650..1b2f1358 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -2151,3 +2151,130 @@ func TestKeyBackup(t *testing.T) { } }) } + +func TestGetMembership(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + + testCases := []struct { + name string + roomID string + user *test.User + additionalEvents func(t *testing.T, room *test.Room) + request func(t *testing.T, room *test.Room, accessToken string) *http.Request + wantOK bool + wantMemberCount int + }{ + + { + name: "/joined_members - Bob never joined", + user: bob, + request: func(t *testing.T, room *test.Room, accessToken string) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": accessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Alice joined", + user: alice, + request: func(t *testing.T, room *test.Room, accessToken string) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": accessToken, + })) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "/joined_members - Alice leaves, shouldn't be able to see members ", + user: alice, + request: func(t *testing.T, room *test.Room, accessToken string) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": accessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + wantOK: false, + }, + { + name: "/joined_members - Bob joins, Alice sees two members", + user: alice, + request: func(t *testing.T, room *test.Room, accessToken string) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": accessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + wantOK: true, + wantMemberCount: 2, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + defer close() + natsInstance := jetstream.NATSInstance{} + jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream) + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + + accessTokens := map[*test.User]userDevice{ + alice: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + room := test.NewRoom(t, alice) + t.Cleanup(func() { + t.Logf("running cleanup for %s", tc.name) + }) + // inject additional events + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + w := httptest.NewRecorder() + routers.Client.ServeHTTP(w, tc.request(t, room, accessTokens[tc.user].accessToken)) + if w.Code != 200 && tc.wantOK { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + t.Logf("[%s] Resp: %s", tc.name, w.Body.String()) + + // check we got the expected events + if tc.wantOK { + memberCount := len(gjson.GetBytes(w.Body.Bytes(), "joined").Map()) + if memberCount != tc.wantMemberCount { + t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount) + } + } + }) + } + }) +} diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go new file mode 100644 index 00000000..84be498d --- /dev/null +++ b/clientapi/routing/memberships.go @@ -0,0 +1,139 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" +) + +// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members +type getJoinedMembersResponse struct { + Joined map[string]joinedMember `json:"joined"` +} + +type joinedMember struct { + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` +} + +// The database stores 'displayname' without an underscore. +// Deserialize into this and then change to the actual API response +type databaseJoinedMember struct { + DisplayName string `json:"displayname"` + AvatarURL string `json:"avatar_url"` +} + +// GetJoinedMembers implements +// +// GET /rooms/{roomId}/joined_members +func GetJoinedMembers( + req *http.Request, device *userapi.Device, roomID string, + rsAPI api.ClientRoomserverAPI, +) util.JSONResponse { + // Validate the userID + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Device UserID is invalid"), + } + } + + // Validate the roomID + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("RoomID is invalid"), + } + } + + // Get the current memberships for the requesting user to determine + // if they are allowed to query this endpoint. + queryReq := api.QueryMembershipForUserRequest{ + RoomID: validRoomID.String(), + UserID: *userID, + } + + var queryRes api.QueryMembershipForUserResponse + if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil { + util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + if !queryRes.HasBeenInRoom { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + } + } + + if !queryRes.IsInRoom { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + } + } + + // Get the current membership events + var membershipsForRoomResp api.QueryMembershipsForRoomResponse + if err = rsAPI.QueryMembershipsForRoom(req.Context(), &api.QueryMembershipsForRoomRequest{ + JoinedOnly: true, + RoomID: validRoomID.String(), + }, &membershipsForRoomResp); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + var res getJoinedMembersResponse + res.Joined = make(map[string]joinedMember) + for _, ev := range membershipsForRoomResp.JoinEvents { + var content databaseJoinedMember + if err := json.Unmarshal(ev.Content, &content); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(ev.Sender)) + if err != nil || userID == nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + res.Joined[userID.String()] = joinedMember(content) + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d4aa1d08..3e23ab40 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1513,4 +1513,14 @@ func Setup( return GetPresence(req, device, natsClient, cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence), vars["userId"]) }), ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/rooms/{roomID}/joined_members", + httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetJoinedMembers(req, device, vars["roomID"], rsAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index e849adf6..9cc937d8 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -15,7 +15,6 @@ package routing import ( - "encoding/json" "math" "net/http" @@ -33,31 +32,13 @@ type getMembershipResponse struct { Chunk []synctypes.ClientEvent `json:"chunk"` } -// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members -type getJoinedMembersResponse struct { - Joined map[string]joinedMember `json:"joined"` -} - -type joinedMember struct { - DisplayName string `json:"display_name"` - AvatarURL string `json:"avatar_url"` -} - -// The database stores 'displayname' without an underscore. -// Deserialize into this and then change to the actual API response -type databaseJoinedMember struct { - DisplayName string `json:"displayname"` - AvatarURL string `json:"avatar_url"` -} - // GetMemberships implements // // GET /rooms/{roomId}/members -// GET /rooms/{roomId}/joined_members func GetMemberships( req *http.Request, device *userapi.Device, roomID string, syncDB storage.Database, rsAPI api.SyncRoomserverAPI, - joinedOnly bool, membership, notMembership *string, at string, + membership, notMembership *string, at string, ) util.JSONResponse { userID, err := spec.NewUserID(device.UserID, true) if err != nil { @@ -87,13 +68,6 @@ func GetMemberships( } } - if joinedOnly && !queryRes.IsInRoom { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."), - } - } - db, err := syncDB.NewDatabaseSnapshot(req.Context()) if err != nil { return util.JSONResponse{ @@ -139,40 +113,6 @@ func GetMemberships( result := qryRes.Events - if joinedOnly { - var res getJoinedMembersResponse - res.Joined = make(map[string]joinedMember) - for _, ev := range result { - var content databaseJoinedMember - if err := json.Unmarshal(ev.Content(), &content); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - - userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) - if err != nil || userID == nil { - util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - if err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), - } - } - res.Joined[userID.String()] = joinedMember(content) - } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: res, - } - } return util.JSONResponse{ Code: http.StatusOK, JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index a837e169..78188d1b 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -197,19 +197,7 @@ func Setup( } at := req.URL.Query().Get("at") - return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at) + return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, membership, notMembership, at) }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) - - v3mux.Handle("/rooms/{roomID}/joined_members", - httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - at := req.URL.Query().Get("at") - membership := spec.Join - return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, true, &membership, nil, at) - }), - ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index ac526851..0a2c38ab 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -753,24 +753,6 @@ func TestGetMembership(t *testing.T) { }, wantOK: false, }, - { - name: "/joined_members - Bob never joined", - request: func(t *testing.T, room *test.Room) *http.Request { - return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ - "access_token": bobDev.AccessToken, - })) - }, - wantOK: false, - }, - { - name: "/joined_members - Alice joined", - request: func(t *testing.T, room *test.Room) *http.Request { - return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ - "access_token": aliceDev.AccessToken, - })) - }, - wantOK: true, - }, { name: "Alice leaves before Bob joins, should not be able to see Bob", request: func(t *testing.T, room *test.Room) *http.Request { @@ -809,21 +791,6 @@ func TestGetMembership(t *testing.T) { wantOK: true, wantMemberCount: 2, }, - { - name: "/joined_members - Alice leaves, shouldn't be able to see members ", - request: func(t *testing.T, room *test.Room) *http.Request { - return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ - "access_token": aliceDev.AccessToken, - })) - }, - additionalEvents: func(t *testing.T, room *test.Room) { - room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{ - "membership": "leave", - }, test.WithStateKey(alice.ID)) - }, - useSleep: true, - wantOK: false, - }, { name: "'at' specified, returns memberships before Bob joins", request: func(t *testing.T, room *test.Room) *http.Request {