diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 62ffa615..2d17e092 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,8 @@ package clientapi import ( + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" @@ -26,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index c50e552b..e371d921 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias( joinReq := roomserverAPI.PerformJoinRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } joinRes := roomserverAPI.PerformJoinResponse{} @@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias( if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { done <- jsonerror.InternalAPIError(req.Context(), err) } else if joinRes.Error != nil { - done <- joinRes.Error.JSONResponse() + if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { + done <- util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), + } + } else { + done <- joinRes.Error.JSONResponse() + } } else { done <- util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go new file mode 100644 index 00000000..9e8208e6 --- /dev/null +++ b/clientapi/routing/joinroom_test.go @@ -0,0 +1,158 @@ +package routing + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestJoinRoomByIDOrAlias(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc + + // Create the users in the userapi + for _, u := range []*test.User{alice, bob, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + } + + aliceDev := &uapi.Device{UserID: alice.ID} + bobDev := &uapi.Device{UserID: bob.ID} + charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest} + + // create a room with disabled guest access and invite Bob + resp := createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + RoomAliasName: "alias", + Invite: []string{bob.ID}, + GuestCanJoin: false, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crResp, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // create a room with guest access enabled and invite Charlie + resp = createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + Invite: []string{charlie.ID}, + GuestCanJoin: true, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // Dummy request + body := &bytes.Buffer{} + req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + device *uapi.Device + roomID string + wantHTTP200 bool + }{ + { + name: "User can join successfully by alias", + device: bobDev, + roomID: crResp.RoomAlias, + wantHTTP200: true, + }, + { + name: "User can join successfully by roomID", + device: bobDev, + roomID: crResp.RoomID, + wantHTTP200: true, + }, + { + name: "join is forbidden if user is guest", + device: charlieDev, + roomID: crResp.RoomID, + }, + { + name: "room does not exist", + device: aliceDev, + roomID: "!doesnotexist:test", + }, + { + name: "user from different server", + device: &uapi.Device{UserID: "@wrong:server"}, + roomID: crResp.RoomAlias, + }, + { + name: "user doesn't exist locally", + device: &uapi.Device{UserID: "@doesnotexist:test"}, + roomID: crResp.RoomAlias, + }, + { + name: "invalid room ID", + device: aliceDev, + roomID: "invalidRoomID", + }, + { + name: "roomAlias does not exist", + device: aliceDev, + roomID: "#doesnotexist:test", + }, + { + name: "room with guest_access event", + device: charlieDev, + roomID: crRespWithGuestAccess.RoomID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID) + if tc.wantHTTP200 && !joinResp.Is2xx() { + t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp) + } + }) + } + }) +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e70e5ea9..e789b956 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -78,6 +78,7 @@ const ( type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` Content map[string]interface{} `json:"content"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"` Unsigned map[string]interface{} `json:"unsigned"` diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a362660..451b3769 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,6 +4,10 @@ import ( "context" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" @@ -19,9 +23,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing + identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + if err != nil { + logrus.Panic(err) + } + r.Inputer = &input.Inputer{ Cfg: &r.Base.Cfg.RoomServer, Base: r.Base, @@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio JetStream: r.JetStream, NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, + SigningIdentity: identity, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Unpeeker = &perform.Unpeeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI + r.Inputer.UserAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e965691c..94131103 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -23,6 +23,8 @@ import ( "sync" "time" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -79,6 +81,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + SigningIdentity *gomatrixserverlib.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -87,6 +90,7 @@ type Inputer struct { workers sync.Map // room ID -> *worker Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI } // If a room consumer is inactive for a while then we will allow NATS diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 10b8ee27..4179fc1e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,6 +19,7 @@ package input import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "time" @@ -31,6 +32,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + userAPI "github.com/matrix-org/dendrite/userapi/api" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent( } } + // If guest_access changed and is not can_join, kick all guest users. + if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil { + logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") + } + } + // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) @@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState( succeeded = true return nil } + +// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. +func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { + membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return err + } + + memberEvents, err := r.DB.Events(ctx, membershipNIDs) + if err != nil { + return err + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + return err + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + if err != nil { + continue + } + + accountRes := &userAPI.QueryAccountByLocalpartResponse{} + if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: senderDomain, + }, accountRes); err != nil { + return err + } + if accountRes.Account == nil { + continue + } + + if accountRes.Account.AccountType != userAPI.AccountTypeGuest { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + return err + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: event.RoomID(), + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + return err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + if err != nil { + return err + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: senderDomain, + SendAsServer: string(senderDomain), + }) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, // Needs to be async, as we otherwise create a deadlock + } + inputRes := &api.InputRoomEventsResponse{} + return r.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 4de008c6..fc7ba940 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID( } } + // If a guest is trying to join a room, check that the room has a m.room.guest_access event + if req.IsGuest { + var guestAccessEvent *gomatrixserverlib.HeaderedEvent + guestAccess := "forbidden" + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { + logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") + } + if guestAccessEvent != nil { + guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String() + } + + // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event + // is present on the room and has the guest_access value can_join. + if guestAccess != "can_join" { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNotAllowed, + Msg: "Guest access is forbidden", + } + } + } + // If we should do a forced federated join then do that. var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 518bb372..595ceb52 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -3,18 +3,23 @@ package roomserver_test import ( "context" "net/http" + "reflect" "testing" "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/userapi" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" ) @@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s return base, db, close } -func Test_SharedUsers(t *testing.T) { +func TestUsers(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + + t.Run("shared users", func(t *testing.T) { + testSharedUsers(t, rsAPI) + }) + + t.Run("kick users", func(t *testing.T) { + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + rsAPI.SetUserAPI(usrAPI) + testKickUsers(t, rsAPI, usrAPI) + }) + }) + +} + +func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { alice := test.NewUser(t) bob := test.NewUser(t) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) @@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) { }, test.WithStateKey(bob.ID)) ctx := context.Background() - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, _, close := mustCreateDatabase(t, dbType) - defer close() - rsAPI := roomserver.NewInternalAPI(base) - // SetFederationAPI starts the room event input consumer - rsAPI.SetFederationAPI(nil, nil) - // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Fatalf("failed to send events: %v", err) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } +} + +func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) { + // Create users and room; Bob is going to be the guest and kicked on revocation of guest access + alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest)) + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) + + // Join with the guest user + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the users in the userapi, so the RSAPI can query the account type later + for _, u := range []*test.User{alice, bob} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &userAPI.PerformAccountCreationResponse{} + if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + // Create the room in the database + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Get the membership events BEFORE revoking guest access + membershipRes := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // revoke guest access + revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need + // to loop and wait for the events to be processed by the roomserver. + for i := 0; i <= 20; i++ { + // Get the membership events AFTER revoking guest access + membershipRes2 := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil { + t.Errorf("failed to query membership for room: %s", err) } - // Query the shared users for Alice, there should only be Bob. - // This is used by the SyncAPI keychange consumer. - res := &api.QuerySharedUsersResponse{} - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + // The membership events should NOT match, as Bob (guest user) should now be kicked from the room + if !reflect.DeepEqual(membershipRes, membershipRes2) { + return } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - // Also verify that we get the expected result when specifying OtherUserIDs. - // This is used by the SyncAPI when getting device list changes. - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) - } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - }) + time.Sleep(time.Millisecond * 10) + } + + t.Errorf("memberships didn't change in time") } func Test_QueryLeftUsers(t *testing.T) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 511951fe..804eb1a2 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return id, nil } } - return nil, fmt.Errorf("no signing identity %q", serverName) + return nil, fmt.Errorf("no signing identity for %q", serverName) } func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index ee7e7389..3408bf46 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -16,8 +16,10 @@ package config import ( "fmt" + "reflect" "testing" + "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) @@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) { } } } + +func Test_SigningIdentityFor(t *testing.T) { + tests := []struct { + name string + virtualHosts []*VirtualHost + serverName gomatrixserverlib.ServerName + want *gomatrixserverlib.SigningIdentity + wantErr bool + }{ + { + name: "no virtual hosts defined", + wantErr: true, + }, + { + name: "no identity found", + serverName: gomatrixserverlib.ServerName("doesnotexist"), + wantErr: true, + }, + { + name: "found identity", + serverName: gomatrixserverlib.ServerName("main"), + want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + }, + { + name: "identity found on virtual hosts", + serverName: gomatrixserverlib.ServerName("vh2"), + virtualHosts: []*VirtualHost{ + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + }, + want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Global{ + VirtualHosts: tt.virtualHosts, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "main", + }, + } + got, err := c.SigningIdentityFor(tt.serverName) + if (err != nil) != tt.wantErr { + t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sytest-blacklist b/sytest-blacklist index c35b03bd..99cfbabc 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one Leaves are present in non-gapped incremental syncs # Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) \ No newline at end of file +New federated private chats get full presence information (SYN-115) + +# We don't have any state to calculate m.room.guest_access when accepting invites +Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 49ffb8fe..215889a4 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -763,4 +763,7 @@ AS and main public room lists are separate local user has tags copied to the new room remote user has tags copied to the new room /upgrade moves remote aliases to the new room -Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file +Local and remote users' homeservers remove a room from their public directory on upgrade +Guest users denied access over federation if guest access prohibited +Guest users are kicked from guest_access rooms on revocation of guest_access +Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file diff --git a/test/room.go b/test/room.go index 4328bf84..685876cb 100644 --- a/test/room.go +++ b/test/room.go @@ -38,11 +38,12 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - visibility gomatrixserverlib.HistoryVisibility - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + guestCanJoin bool + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) { r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + if r.guestCanJoin { + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + "guest_access": "can_join", + }, WithStateKey("")) + } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. @@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { r.Version = ver } } + +func GuestsCanJoin(canJoin bool) roomModifier { + return func(t *testing.T, r *Room) { + r.guestCanJoin = canJoin + } +} diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc..4ea2e91c 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,6 +50,7 @@ type KeyserverUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index ce661770..d10b5767 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { + err := t.Impl.QueryAccountByLocalpart(ctx, req, res) + util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3f256457..0bb480da 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 87ae058c..51b0fe3e 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -60,6 +60,7 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryAccountByLocalpartPath = "/userapi/queryAccountType" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( h.httpClient, ctx, request, response, ) } + +func (h *httpUserInternalAPI) QueryAccountByLocalpart( + ctx context.Context, + req *api.QueryAccountByLocalpartRequest, + res *api.QueryAccountByLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, + h.httpClient, ctx, req, res, + ) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index f0579079..b40b507c 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics PerformSaveThreePIDAssociationPath, httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) + + internalAPIMux.Handle( + QueryAccountByLocalpartPath, + httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), + ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8a19af19..dada56de 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + t.Run("Monolith", func(t *testing.T) { + testCases(t, intAPI) + // also test tracing + testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) + }) + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, intAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + testCases(t, userHTTPApi) + + }) + }) +}