From 05cafbd197c99c0e116c9b61447e70ba5af992a3 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 11 Aug 2022 18:23:35 +0200 Subject: [PATCH] Implement history visibility on `/messages`, `/context`, `/sync` (#2511) * Add possibility to set history_visibility and user AccountType * Add new DB queries * Add actual history_visibility changes for /messages * Add passing tests * Extract check function * Cleanup * Cleanup * Fix build on 386 * Move ApplyHistoryVisibilityFilter to internal * Move queries to topology table * Add filtering to /sync and /context Some cleanup * Add passing tests; Remove failing tests :( * Re-add passing tests * Move filtering to own function to avoid duplication * Re-add passing test * Use newly added GMSL HistoryVisibility * Update gomatrixserverlib * Set the visibility when creating events * Default to shared history visibility * Remove unused query * Update history visibility checks to use gmsl Update tests * Remove unused statement * Update migrations to set "correct" history visibility * Add method to fetch the membership at a given event * Tweaks and logging * Use actual internal rsAPI, default to shared visibility in tests * Revert "Move queries to topology table" This reverts commit 4f0d41be9c194a46379796435ce73e79203edbd6. * Remove noise/unneeded code * More cleanup * Try to optimize database requests * Fix imports * PR peview fixes/changes * Move setting history visibility to own migration, be more restrictive * Fix unit tests * Lint * Fix missing entries * Tweaks for incremental syncs * Adapt generic changes Co-authored-by: Neil Alexander Co-authored-by: kegsay --- go.mod | 8 +- go.sum | 16 +- roomserver/api/api.go | 8 + roomserver/api/api_trace.go | 10 + roomserver/api/query.go | 14 ++ roomserver/internal/helpers/helpers.go | 6 + roomserver/internal/input/input_events.go | 4 +- roomserver/internal/query/query.go | 48 ++++ roomserver/inthttp/client.go | 12 +- roomserver/inthttp/server.go | 5 + roomserver/state/state.go | 60 ++++- syncapi/internal/history_visibility.go | 217 ++++++++++++++++++ syncapi/routing/context.go | 99 ++++++-- syncapi/routing/messages.go | 100 ++------ syncapi/storage/interface.go | 4 + ...2061412000000_history_visibility_column.go | 55 +++++ syncapi/storage/postgres/memberships_table.go | 28 ++- .../postgres/output_room_events_table.go | 10 +- syncapi/storage/postgres/syncserver.go | 15 ++ syncapi/storage/shared/syncserver.go | 49 ++-- ...2061412000000_history_visibility_column.go | 55 +++++ syncapi/storage/sqlite3/memberships_table.go | 22 ++ .../sqlite3/output_room_events_table.go | 10 +- syncapi/storage/sqlite3/syncserver.go | 19 +- syncapi/storage/storage_test.go | 20 +- syncapi/storage/tables/interface.go | 1 + syncapi/streams/stream_pdu.go | 123 ++++++---- syncapi/syncapi_test.go | 186 ++++++++++++++- sytest-whitelist | 25 +- test/room.go | 28 ++- test/user.go | 10 +- 31 files changed, 1043 insertions(+), 224 deletions(-) create mode 100644 syncapi/internal/history_visibility.go diff --git a/go.mod b/go.mod index 3559c5bb..79ef5c3e 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.13 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.12.2 - github.com/sirupsen/logrus v1.8.1 + github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.7.1 github.com/tidwall/gjson v1.14.1 github.com/tidwall/sjson v1.2.4 @@ -42,7 +42,7 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.3 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9 golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e @@ -99,7 +99,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect - golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect + golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect golang.org/x/tools v0.1.10 // indirect diff --git a/go.sum b/go.sum index 2c8bb4f1..9d5c50d2 100644 --- a/go.sum +++ b/go.sum @@ -343,8 +343,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 h1:ZIPHFIPNDS9dmEbPEiJbNmyCGJtn9exfpLC7JOcn/bE= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 h1:QEFxKWH8PlEt3ZQKl31yJNAm8lvpNUwT51IMNTl9v1k= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -493,8 +493,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= -github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= @@ -569,8 +569,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -748,8 +748,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ee0212ec..baf63aa3 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -97,6 +97,14 @@ type SyncRoomserverAPI interface { req *PerformBackfillRequest, res *PerformBackfillResponse, ) error + + // QueryMembershipAtEvent queries the memberships at the given events. + // Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent. + QueryMembershipAtEvent( + ctx context.Context, + request *QueryMembershipAtEventRequest, + response *QueryMembershipAtEventResponse, + ) error } type AppserviceRoomserverAPI interface { diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 18a61733..8bef3537 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -382,6 +382,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed( return err } +func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent( + ctx context.Context, + request *QueryMembershipAtEventRequest, + response *QueryMembershipAtEventResponse, +) error { + err := t.Impl.QueryMembershipAtEvent(ctx, request, response) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index f157a902..c8e6f9dc 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -427,3 +427,17 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { } return nil } + +// QueryMembershipAtEventRequest requests the membership events for a user +// for a list of eventIDs. +type QueryMembershipAtEventRequest struct { + RoomID string + EventIDs []string + UserID string +} + +// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest. +type QueryMembershipAtEventResponse struct { + // Memberships is a map from eventID to a list of events (if any). + Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` +} diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index d61aa08c..6091f8ec 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -208,6 +208,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } +func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) + // Fetch the state as it was when this event was fired + return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) +} + func LoadEvents( ctx context.Context, db storage.Database, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 866670d7..81541260 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -299,7 +299,7 @@ func (r *Inputer) processRoomEvent( // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up // burning CPU time. - historyVisibility := gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive. + historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. if rejectionErr == nil && !isRejected && !softfail { var err error historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) @@ -429,7 +429,7 @@ func (r *Inputer) processStateBefore( input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { - historyVisibility = gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive. + historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared. event := input.Event.Unwrap() isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") var stateBeforeEvent []*gomatrixserverlib.Event diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 5ba59b8f..c41e1ea6 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -204,6 +204,54 @@ func (r *Queryer) QueryMembershipForUser( return err } +func (r *Queryer) QueryMembershipAtEvent( + ctx context.Context, + request *api.QueryMembershipAtEventRequest, + response *api.QueryMembershipAtEventResponse, +) error { + response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return fmt.Errorf("unable to get roomInfo: %w", err) + } + if info == nil { + return fmt.Errorf("no roomInfo found") + } + + // get the users stateKeyNID + stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) + if err != nil { + return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) + } + if _, ok := stateKeyNIDs[request.UserID]; !ok { + return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) + } + + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) + if err != nil { + return fmt.Errorf("unable to get state before event: %w", err) + } + + for _, eventID := range request.EventIDs { + stateEntry := stateEntries[eventID] + memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + if err != nil { + return fmt.Errorf("unable to get memberships at state: %w", err) + } + res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) + + for i := range memberships { + ev := memberships[i] + if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { + res = append(res, ev.Headered(info.RoomVersion)) + } + } + response.Memberships[eventID] = res + } + + return nil +} + // QueryMembershipsForRoom implements api.RoomserverInternalAPI func (r *Queryer) QueryMembershipsForRoom( ctx context.Context, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index d16f67c6..e9387bd9 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -5,14 +5,14 @@ import ( "errors" "net/http" + "github.com/matrix-org/gomatrixserverlib" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsInputAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - - "github.com/matrix-org/gomatrixserverlib" ) const ( @@ -61,6 +61,7 @@ const ( RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" + RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" ) type httpRoomserverInternalAPI struct { @@ -529,3 +530,10 @@ func (h *httpRoomserverInternalAPI) PerformForget( ) } + +func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error { + return httputil.CallInternalRPCAPI( + "QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index e325d76a..3b688174 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -2,6 +2,7 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" ) @@ -193,4 +194,8 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { RoomserverQueryRestrictedJoinAllowed, httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), ) + internalAPIMux.Handle( + RoomserverQueryMembershipAtEventPath, + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), + ) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index ca0c69f2..a40a2e9b 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -23,12 +23,11 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) type StateResolutionStorage interface { @@ -124,6 +123,61 @@ func (v *StateResolution) LoadStateAtEvent( return stateEntries, nil } +func (v *StateResolution) LoadMembershipAtEvent( + ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID, +) (map[string][]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") + defer span.Finish() + + // De-dupe snapshotNIDs + snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs + for i := range eventIDs { + eventID := eventIDs[i] + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) + } + if snapshotNID == 0 { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) + } + snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID) + } + + snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap)) + for nid := range snapshotNIDMap { + snapshotNIDs = append(snapshotNIDs, nid) + } + + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs) + if err != nil { + return nil, err + } + + result := make(map[string][]types.StateEntry) + for _, stateBlockNIDList := range stateBlockNIDLists { + // Query the membership event for the user at the given stateblocks + stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{ + { + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }, + }) + if err != nil { + return nil, err + } + + evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID] + + for _, evID := range evIDs { + for _, x := range stateEntryLists { + result[evID] = append(result[evID], x.StateEntries...) + } + } + } + + return result, nil +} + // LoadStateAtEvent loads the full state of a room before a particular event. func (v *StateResolution) LoadStateAtEventForHistoryVisibility( ctx context.Context, eventID string, diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go new file mode 100644 index 00000000..e73c004e --- /dev/null +++ b/syncapi/internal/history_visibility.go @@ -0,0 +1,217 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "math" + "time" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/prometheus/client_golang/prometheus" + "github.com/tidwall/gjson" +) + +func init() { + prometheus.MustRegister(calculateHistoryVisibilityDuration) +} + +// calculateHistoryVisibilityDuration stores the time it takes to +// calculate the history visibility. In polylith mode the roundtrip +// to the roomserver is included in this time. +var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "dendrite", + Subsystem: "syncapi", + Name: "calculateHistoryVisibility_duration_millis", + Help: "How long it takes to calculate the history visibility", + Buckets: []float64{ // milliseconds + 5, 10, 25, 50, 75, 100, 250, 500, + 1000, 2000, 3000, 4000, 5000, 6000, + 7000, 8000, 9000, 10000, 15000, 20000, + }, + }, + []string{"api"}, +) + +var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{ + gomatrixserverlib.WorldReadable: 0, + gomatrixserverlib.HistoryVisibilityShared: 1, + gomatrixserverlib.HistoryVisibilityInvited: 2, + gomatrixserverlib.HistoryVisibilityJoined: 3, +} + +// eventVisibility contains the history visibility and membership state at a given event +type eventVisibility struct { + visibility gomatrixserverlib.HistoryVisibility + membershipAtEvent string + membershipCurrent string +} + +// allowed checks the eventVisibility if the user is allowed to see the event. +// Rules as defined by https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5 +func (ev eventVisibility) allowed() (allowed bool) { + switch ev.visibility { + case gomatrixserverlib.HistoryVisibilityWorldReadable: + // If the history_visibility was set to world_readable, allow. + return true + case gomatrixserverlib.HistoryVisibilityJoined: + // If the user’s membership was join, allow. + if ev.membershipAtEvent == gomatrixserverlib.Join { + return true + } + return false + case gomatrixserverlib.HistoryVisibilityShared: + // If the user’s membership was join, allow. + // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow. + if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join { + return true + } + return false + case gomatrixserverlib.HistoryVisibilityInvited: + // If the user’s membership was join, allow. + if ev.membershipAtEvent == gomatrixserverlib.Join { + return true + } + if ev.membershipAtEvent == gomatrixserverlib.Invite { + return true + } + return false + default: + return false + } +} + +// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents. +// Returns the filtered events and an error, if any. +func ApplyHistoryVisibilityFilter( + ctx context.Context, + syncDB storage.Database, + rsAPI api.SyncRoomserverAPI, + events []*gomatrixserverlib.HeaderedEvent, + alwaysIncludeEventIDs map[string]struct{}, + userID, endpoint string, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + if len(events) == 0 { + return events, nil + } + start := time.Now() + + // try to get the current membership of the user + membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) + if err != nil { + return nil, err + } + + // Get the mapping from eventID -> eventVisibility + eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) + visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) + if err != nil { + return eventsFiltered, err + } + for _, ev := range events { + evVis := visibilities[ev.EventID()] + evVis.membershipCurrent = membershipCurrent + // Always include specific state events for /sync responses + if alwaysIncludeEventIDs != nil { + if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok { + eventsFiltered = append(eventsFiltered, ev) + continue + } + } + // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") + if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) { + eventsFiltered = append(eventsFiltered, ev) + continue + } + // Always allow history evVis events on boundaries. This is done + // by setting the effective evVis to the least restrictive + // of the old vs new. + // https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5 + if hisVis, err := ev.HistoryVisibility(); err == nil { + prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String() + oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)] + // if we can't get the previous history visibility, default to shared. + if !ok { + oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared] + } + // no OK check, since this should have been validated when setting the value + newPrio := historyVisibilityPriority[hisVis] + if oldPrio < newPrio { + evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis) + } + } + // do the actual check + allowed := evVis.allowed() + if allowed { + eventsFiltered = append(eventsFiltered, ev) + } + } + calculateHistoryVisibilityDuration.With(prometheus.Labels{"api": endpoint}).Observe(float64(time.Since(start).Milliseconds())) + return eventsFiltered, nil +} + +// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership +// of `userID` at the given event. +// Returns an error if the roomserver can't calculate the memberships. +func visibilityForEvents( + ctx context.Context, + rsAPI api.SyncRoomserverAPI, + events []*gomatrixserverlib.HeaderedEvent, + userID, roomID string, +) (map[string]eventVisibility, error) { + eventIDs := make([]string, len(events)) + for i := range events { + eventIDs[i] = events[i].EventID() + } + + result := make(map[string]eventVisibility, len(eventIDs)) + + // get the membership events for all eventIDs + membershipResp := &api.QueryMembershipAtEventResponse{} + err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ + RoomID: roomID, + EventIDs: eventIDs, + UserID: userID, + }, membershipResp) + if err != nil { + return result, err + } + + // Create a map from eventID -> eventVisibility + for _, event := range events { + eventID := event.EventID() + vis := eventVisibility{ + membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident + visibility: event.Visibility, + } + membershipEvs, ok := membershipResp.Memberships[eventID] + if !ok { + result[eventID] = vis + continue + } + for _, ev := range membershipEvs { + membership, err := ev.Membership() + if err != nil { + return result, err + } + vis.membershipAtEvent = membership + } + result[eventID] = vis + } + return result, nil +} diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index f6b4d15e..13c4e9d8 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -21,10 +21,12 @@ import ( "fmt" "net/http" "strconv" + "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -95,24 +97,6 @@ func Context( ContainsURL: filter.ContainsURL, } - // TODO: Get the actual state at the last event returned by SelectContextAfterEvent - state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) - // verify the user is allowed to see the context for this room/event - for _, x := range state { - var hisVis gomatrixserverlib.HistoryVisibility - hisVis, err = x.HistoryVisibility() - if err != nil { - continue - } - allowed := hisVis == gomatrixserverlib.WorldReadable || membershipRes.Membership == gomatrixserverlib.Join - if !allowed { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("User is not allowed to query context"), - } - } - } - id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) if err != nil { if err == sql.ErrNoRows { @@ -125,6 +109,24 @@ func Context( return jsonerror.InternalServerError() } + // verify the user is allowed to see the context for this room/event + startTime := time.Now() + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + return jsonerror.InternalServerError() + } + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": roomID, + }).Debug("applied history visibility (context)") + if len(filteredEvents) == 0 { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("User is not allowed to query context"), + } + } + eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") @@ -137,8 +139,27 @@ func Context( return jsonerror.InternalServerError() } - eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll) - eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll) + startTime = time.Now() + eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID) + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + return jsonerror.InternalServerError() + } + + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": roomID, + }).Debug("applied history visibility (context eventsBefore/eventsAfter)") + + // TODO: Get the actual state at the last event returned by SelectContextAfterEvent + state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) + if err != nil { + logrus.WithError(err).Error("unable to fetch current room state") + return jsonerror.InternalServerError() + } + + eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) + eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) response := ContextRespsonse{ @@ -162,6 +183,44 @@ func Context( } } +// applyHistoryVisibilityOnContextEvents is a helper function to avoid roundtrips to the roomserver +// by combining the events before and after the context event. Returns the filtered events, +// and an error, if any. +func applyHistoryVisibilityOnContextEvents( + ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI, + eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, + userID string, +) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { + eventIDsBefore := make(map[string]struct{}, len(eventsBefore)) + eventIDsAfter := make(map[string]struct{}, len(eventsAfter)) + + // Remember before/after eventIDs, so we can restore them + // after applying history visibility checks + for _, ev := range eventsBefore { + eventIDsBefore[ev.EventID()] = struct{}{} + } + for _, ev := range eventsAfter { + eventIDsAfter[ev.EventID()] = struct{}{} + } + + allEvents := append(eventsBefore, eventsAfter...) + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context") + if err != nil { + return nil, nil, err + } + + // "Restore" events in the correct context + for _, ev := range filteredEvents { + if _, ok := eventIDsBefore[ev.EventID()]; ok { + filteredBefore = append(filteredBefore, ev) + } + if _, ok := eventIDsAfter[ev.EventID()]; ok { + filteredAfter = append(filteredAfter, ev) + } + } + return filteredBefore, filteredAfter, nil +} + func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { if len(startEvents) > 0 { start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID()) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index b4c9a542..9db3d8e1 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -19,6 +19,7 @@ import ( "fmt" "net/http" "sort" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -28,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" @@ -324,6 +326,9 @@ func (r *messagesReq) retrieveEvents() ( // reliable way to define it), it would be easier and less troublesome to // only have to change it in one place, i.e. the database. start, end, err = r.getStartEnd(events) + if err != nil { + return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err + } // Sort the events to ensure we send them in the right order. if r.backwardOrdering { @@ -337,97 +342,18 @@ func (r *messagesReq) retrieveEvents() ( } events = reversed(events) } - events = r.filterHistoryVisible(events) if len(events) == 0 { return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil } - // Convert all of the events into client events. - clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll) - return clientEvents, start, end, err -} - -func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the - // user shouldn't see, we check the recent events and remove any prior to the join event of the user - // which is equiv to history_visibility: joined - joinEventIndex := -1 - for i, ev := range events { - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) { - membership, _ := ev.Membership() - if membership == "join" { - joinEventIndex = i - break - } - } - } - - var result []*gomatrixserverlib.HeaderedEvent - var eventsToCheck []*gomatrixserverlib.HeaderedEvent - if joinEventIndex != -1 { - if r.backwardOrdering { - result = events[:joinEventIndex+1] - eventsToCheck = append(eventsToCheck, result[0]) - } else { - result = events[joinEventIndex:] - eventsToCheck = append(eventsToCheck, result[len(result)-1]) - } - } else { - eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]} - result = events - } - // make sure the user was in the room for both the earliest and latest events, we need this because - // some backpagination results will not have the join event (e.g if they hit /messages at the join event itself) - wasJoined := true - for _, ev := range eventsToCheck { - var queryRes api.QueryStateAfterEventsResponse - err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{ - RoomID: ev.RoomID(), - PrevEventIDs: ev.PrevEventIDs(), - StateToFetch: []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID}, - {EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""}, - }, - }, &queryRes) - if err != nil { - wasJoined = false - break - } - var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent - for i := range queryRes.StateEvents { - switch queryRes.StateEvents[i].Type() { - case gomatrixserverlib.MRoomMember: - membershipEvent = queryRes.StateEvents[i] - case gomatrixserverlib.MRoomHistoryVisibility: - hisVisEvent = queryRes.StateEvents[i] - } - } - if hisVisEvent == nil { - return events // apply no filtering as it defaults to Shared. - } - hisVis, _ := hisVisEvent.HistoryVisibility() - if hisVis == "shared" || hisVis == "world_readable" { - return events // apply no filtering - } - if membershipEvent == nil { - wasJoined = false - break - } - membership, err := membershipEvent.Membership() - if err != nil { - wasJoined = false - break - } - if membership != "join" { - wasJoined = false - break - } - } - if !wasJoined { - util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID) - return []*gomatrixserverlib.HeaderedEvent{} - } - return result + // Apply room history visibility filter + startTime := time.Now() + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages") + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": r.roomID, + }).Debug("applied history visibility (messages)") + return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err } func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 9751670b..43a75da9 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -161,6 +161,10 @@ type Database interface { IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error + // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found + // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty + // string as the membership. + SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) } type Presence interface { diff --git a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go index 29008ade..d68ed8d5 100644 --- a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go +++ b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go @@ -17,7 +17,10 @@ package deltas import ( "context" "database/sql" + "encoding/json" "fmt" + + "github.com/matrix-org/gomatrixserverlib" ) func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error { @@ -31,6 +34,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T return nil } +// UpSetHistoryVisibility sets the history visibility for already stored events. +// Requires current_room_state and output_room_events to be created. +func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error { + // get the current room history visibilities + historyVisibilities, err := currentHistoryVisibilities(ctx, tx) + if err != nil { + return err + } + + // update the history visibility + for roomID, hisVis := range historyVisibilities { + _, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1 + WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID) + if err != nil { + return fmt.Errorf("failed to update history visibility: %w", err) + } + } + + return nil +} + func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2; @@ -39,9 +63,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + return nil } +// currentHistoryVisibilities returns a map from roomID to current history visibility. +// If the history visibility was changed after room creation, defaults to joined. +func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) { + rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state + WHERE type = 'm.room.history_visibility' AND state_key = ''; +`) + if err != nil { + return nil, fmt.Errorf("failed to query current room state: %w", err) + } + defer rows.Close() // nolint: errcheck + var eventBytes []byte + var roomID string + var event gomatrixserverlib.HeaderedEvent + var hisVis gomatrixserverlib.HistoryVisibility + historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility) + for rows.Next() { + if err = rows.Scan(&roomID, &eventBytes); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + if err = json.Unmarshal(eventBytes, &event); err != nil { + return nil, fmt.Errorf("failed to unmarshal event: %w", err) + } + historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined + if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 { + historyVisibilities[roomID] = hisVis + } + } + return historyVisibilities, nil +} + func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility; diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 00223c57..939d6b3f 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" + const selectHeroesSQL = "" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" +const selectMembershipBeforeSQL = "" + + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" + type membershipsStatements struct { - upsertMembershipStmt *sql.Stmt - selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt + upsertMembershipStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectHeroesStmt *sql.Stmt + selectMembershipForUserStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectHeroesStmt, selectHeroesSQL}, + {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, }.Prepare(db) } @@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes( } return heroes, rows.Err() } + +// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found +// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty +// string as the membership. +func (s *membershipsStatements) SelectMembershipForUser( + ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64, +) (membership string, topologyPos int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos) + if err != nil { + if err == sql.ErrNoRows { + return "leave", 0, nil + } + return "", 0, err + } + return membership, topologyPos, nil +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 34ff6700..8f633640 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -191,10 +191,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "syncapi: add history visibility column (output_room_events)", - Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: add history visibility column (output_room_events)", + Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, + }, + ) err = m.Up(context.Background()) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index a044716c..979ff664 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/shared" ) @@ -97,6 +98,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) if err != nil { return nil, err } + + // apply migrations which need multiple tables + m := sqlutil.NewMigrator(d.db) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: set history visibility for existing events", + Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created. + }, + ) + err = m.Up(base.Context()) + if err != nil { + return nil, err + } + d.Database = shared.Database{ DB: d.db, Writer: d.writer, diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 932bda16..a46e5525 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -231,7 +231,7 @@ func (d *Database) AddPeek( return } -// DeletePeeks tracks the fact that a user has stopped peeking from the specified +// DeletePeek tracks the fact that a user has stopped peeking from the specified // device. If the peeks was successfully deleted this returns the stream ID it was // stored at. Returns an error if there was a problem communicating with the database. func (d *Database) DeletePeek( @@ -372,6 +372,7 @@ func (d *Database) WriteEvent( ) (pduPosition types.StreamPosition, returnErr error) { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error + ev.Visibility = historyVisibility pos, err := d.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, ) @@ -563,7 +564,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda return err } -// Retrieve the backward topology position, i.e. the position of the +// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the // oldest event in the room's topology. func (d *Database) GetBackwardTopologyPos( ctx context.Context, @@ -674,7 +675,7 @@ func (d *Database) fetchMissingStateEvents( return events, nil } -// getStateDeltas returns the state deltas between fromPos and toPos, +// GetStateDeltas returns the state deltas between fromPos and toPos, // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. @@ -812,7 +813,7 @@ func (d *Database) GetStateDeltas( return deltas, joinedRoomIDs, nil } -// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync // requests with full_state=true. // Fetches full state for all joined rooms and uses selectStateInRange to get // updates for other rooms. @@ -1039,37 +1040,41 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) } -func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) +func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } -func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) } -func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) } -func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { - return s.Ignores.SelectIgnores(ctx, userID) +func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { + return d.Ignores.SelectIgnores(ctx, userID) } -func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { - return s.Ignores.UpsertIgnores(ctx, userID, ignores) +func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { + return d.Ignores.UpsertIgnores(ctx, userID, ignores) } -func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { - return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { + return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) } -func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return s.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, nil, userID) } -func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { - return s.Presence.GetPresenceAfter(ctx, nil, after, filter) +func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return d.Presence.GetPresenceAfter(ctx, nil, after, filter) } -func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { - return s.Presence.GetMaxPresenceID(ctx, nil) +func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return d.Presence.GetMaxPresenceID(ctx, nil) +} + +func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { + return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) } diff --git a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go index 07917721..d23f0756 100644 --- a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go +++ b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go @@ -17,7 +17,10 @@ package deltas import ( "context" "database/sql" + "encoding/json" "fmt" + + "github.com/matrix-org/gomatrixserverlib" ) func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error { @@ -37,6 +40,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T return nil } +// UpSetHistoryVisibility sets the history visibility for already stored events. +// Requires current_room_state and output_room_events to be created. +func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error { + // get the current room history visibilities + historyVisibilities, err := currentHistoryVisibilities(ctx, tx) + if err != nil { + return err + } + + // update the history visibility + for roomID, hisVis := range historyVisibilities { + _, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1 + WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID) + if err != nil { + return fmt.Errorf("failed to update history visibility: %w", err) + } + } + + return nil +} + func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error { // SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists. // Required for unit tests, as otherwise a duplicate column error will show up. @@ -51,9 +75,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + return nil } +// currentHistoryVisibilities returns a map from roomID to current history visibility. +// If the history visibility was changed after room creation, defaults to joined. +func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) { + rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state + WHERE type = 'm.room.history_visibility' AND state_key = ''; +`) + if err != nil { + return nil, fmt.Errorf("failed to query current room state: %w", err) + } + defer rows.Close() // nolint: errcheck + var eventBytes []byte + var roomID string + var event gomatrixserverlib.HeaderedEvent + var hisVis gomatrixserverlib.HistoryVisibility + historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility) + for rows.Next() { + if err = rows.Scan(&roomID, &eventBytes); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + if err = json.Unmarshal(eventBytes, &event); err != nil { + return nil, fmt.Errorf("failed to unmarshal event: %w", err) + } + historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined + if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 { + historyVisibilities[roomID] = hisVis + } + } + return historyVisibilities, nil +} + func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error { // SQLite doesn't have "if exists", so check if the column exists. _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1") diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index e4daa99c..0c966fca 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" + const selectHeroesSQL = "" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" +const selectMembershipBeforeSQL = "" + + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic + selectMembershipForUserStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, + {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes( } return heroes, rows.Err() } + +// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found +// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty +// string as the membership. +func (s *membershipsStatements) SelectMembershipForUser( + ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64, +) (membership string, topologyPos int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos) + if err != nil { + if err == sql.ErrNoRows { + return "leave", 0, nil + } + return "", 0, err + } + return membership, topologyPos, nil +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index de389fa9..91fd35b5 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -139,10 +139,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "syncapi: add history visibility column (output_room_events)", - Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: add history visibility column (output_room_events)", + Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, + }, + ) err = m.Up(context.Background()) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 5c5eb0f5..a84e2bd1 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -16,12 +16,14 @@ package sqlite3 import ( + "context" "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/shared" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" ) // SyncServerDatasource represents a sync server datasource which manages @@ -41,13 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } - if err = d.prepare(); err != nil { + if err = d.prepare(base.Context()); err != nil { return nil, err } return &d, nil } -func (d *SyncServerDatasource) prepare() (err error) { +func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) { if err = d.streamID.Prepare(d.db); err != nil { return err } @@ -107,6 +109,19 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + + // apply migrations which need multiple tables + m := sqlutil.NewMigrator(d.db) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: set history visibility for existing events", + Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created. + }, + ) + err = m.Up(ctx) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Writer: d.writer, diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index eda5ef3e..a62818e9 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -12,20 +12,22 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" ) var ctx = context.Background() -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{ + base, closeBase := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } - return db, close + return db, close, closeBase } func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { @@ -51,8 +53,9 @@ func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { alice := test.NewUser(t) r := test.NewRoom(t, alice) - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() MustWriteEvents(t, db, r.Events()) }) } @@ -60,8 +63,9 @@ func TestWriteEvents(t *testing.T) { // These tests assert basic functionality of RecentEvents for PDUs func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() alice := test.NewUser(t) // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) @@ -163,8 +167,9 @@ func TestRecentEventsPDU(t *testing.T) { // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() alice := test.NewUser(t) r := test.NewRoom(t, alice) for i := 0; i < 10; i++ { @@ -404,8 +409,9 @@ func TestSendToDeviceBehaviour(t *testing.T) { bob := test.NewUser(t) deviceID := "one" test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() // At this point there should be no messages. We haven't sent anything // yet. _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index d68351d4..468d26ac 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -185,6 +185,7 @@ type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) + SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) } type NotificationData interface { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 1ad3adc4..136cbea5 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -10,10 +10,13 @@ import ( "github.com/matrix-org/dendrite/internal/caching" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/internal" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "go.uber.org/atomic" @@ -123,7 +126,7 @@ func (p *PDUStreamProvider) CompleteSync( defer reqWaitGroup.Done() jr, jerr := p.getJoinResponseForCompleteSync( - ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, + ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") @@ -149,7 +152,7 @@ func (p *PDUStreamProvider) CompleteSync( if !peek.Deleted { var jr *types.JoinResponse jr, err = p.getJoinResponseForCompleteSync( - ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, + ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") @@ -281,12 +284,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } } - if len(recentEvents) > 0 { - updateLatestPosition(recentEvents[len(recentEvents)-1].EventID()) - } - if len(delta.StateEvents) > 0 { - updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) - } if stateFilter.LazyLoadMembers { delta.StateEvents, err = p.lazyLoadMembers( @@ -306,6 +303,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } + // Applies the history visibility rules + events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + } + + if len(events) > 0 { + updateLatestPosition(events[len(events)-1].EventID()) + } + if len(delta.StateEvents) > 0 { + updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) + } + switch delta.Membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() @@ -313,14 +323,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = limited + jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[delta.RoomID] = *jr case gomatrixserverlib.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch + // TODO: Apply history visibility on peeked rooms jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) @@ -330,12 +343,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( fallthrough // transitions to leave are the same as ban case gomatrixserverlib.Ban: - // TODO: recentEvents may contain events that this user is not allowed to see because they are - // no longer in the room. lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true + lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) res.Rooms.Leave[delta.RoomID] = *lr } @@ -343,6 +356,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } +// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make +// sure we always return the required events in the timeline. +func applyHistoryVisibilityFilter( + ctx context.Context, + db storage.Database, + rsAPI roomserverAPI.SyncRoomserverAPI, + roomID, userID string, + limit int, + recentEvents []*gomatrixserverlib.HeaderedEvent, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + // We need to make sure we always include the latest states events, if they are in the timeline. + // We grep at least limit * 2 events, to ensure we really get the needed events. + stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) + if err != nil { + // Not a fatal error, we can continue without the stateEvents, + // they are only needed if there are state events in the timeline. + logrus.WithError(err).Warnf("failed to get current room state") + } + alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } + startTime := time.Now() + events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") + if err != nil { + + return nil, err + } + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": roomID, + }).Debug("applied history visibility (sync)") + return events, nil +} + func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { // Work out how many members are in the room. joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) @@ -390,6 +438,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( eventFilter *gomatrixserverlib.RoomEventFilter, wantFullState bool, device *userapi.Device, + isPeek bool, ) (jr *types.JoinResponse, err error) { jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. @@ -404,33 +453,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the - // user shouldn't see, we check the recent events and remove any prior to the join event of the user - // which is equiv to history_visibility: joined - joinEventIndex := -1 - for i := len(recentStreamEvents) - 1; i >= 0; i-- { - ev := recentStreamEvents[i] - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) { - membership, _ := ev.Membership() - if membership == "join" { - joinEventIndex = i - if i > 0 { - // the create event happens before the first join, so we should cut it at that point instead - if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") { - joinEventIndex = i - 1 - break - } - } - break - } - } - } - if joinEventIndex != -1 { - // cut all events earlier than the join (but not the join itself) - recentStreamEvents = recentStreamEvents[joinEventIndex:] - limited = false // so clients know not to try to backpaginate - } - // Work our way through the timeline events and pick out the event IDs // of any state events that appear in the timeline. We'll specifically // exclude them at the next step, so that we don't get duplicate state @@ -474,6 +496,19 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) + events := recentEvents + // Only apply history visibility checks if the response is for joined rooms + if !isPeek { + events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + } + } + + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + limited = limited && len(events) == len(recentEvents) + if stateFilter.LazyLoadMembers { if err != nil { return nil, err @@ -488,8 +523,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = limited + jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) return jr, nil } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 931fef88..dc073a16 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -12,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/producers" keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -54,6 +55,16 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap return nil } +func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error { + res.IsRoomForgotten = false + res.RoomExists = true + return nil +} + +func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error { + return nil +} + type syncUserAPI struct { userapi.SyncUserAPI accounts []userapi.Device @@ -107,7 +118,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - msgs := toNATSMsgs(t, base, room.Events()) + msgs := toNATSMsgs(t, base, room.Events()...) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) testrig.MustPublishMsgs(t, jsctx, msgs...) @@ -200,7 +211,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.power_levels // m.room.join_rules // m.room.history_visibility - msgs := toNATSMsgs(t, base, room.Events()) + msgs := toNATSMsgs(t, base, room.Events()...) sinceTokens := make([]string, len(msgs)) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) for i, msg := range msgs { @@ -315,6 +326,174 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { } +// This is mainly what Sytest is doing in "test_history_visibility" +func TestMessageHistoryVisibility(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testHistoryVisibility(t, dbType) + }) +} + +func testHistoryVisibility(t *testing.T, dbType test.DBType) { + type result struct { + seeWithoutJoin bool + seeBeforeJoin bool + seeAfterInvite bool + } + + // create the users + alice := test.NewUser(t) + bob := test.NewUser(t) + + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "BOD_BEARER_TOKEN", + DisplayName: "BOB", + } + + ctx := context.Background() + // check guest and normal user accounts + for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} { + testCases := []struct { + historyVisibility gomatrixserverlib.HistoryVisibility + wantResult result + }{ + { + historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable, + wantResult: result{ + seeWithoutJoin: true, + seeBeforeJoin: true, + seeAfterInvite: true, + }, + }, + { + historyVisibility: gomatrixserverlib.HistoryVisibilityShared, + wantResult: result{ + seeWithoutJoin: false, + seeBeforeJoin: true, + seeAfterInvite: true, + }, + }, + { + historyVisibility: gomatrixserverlib.HistoryVisibilityInvited, + wantResult: result{ + seeWithoutJoin: false, + seeBeforeJoin: false, + seeAfterInvite: true, + }, + }, + { + historyVisibility: gomatrixserverlib.HistoryVisibilityJoined, + wantResult: result{ + seeWithoutJoin: false, + seeBeforeJoin: false, + seeAfterInvite: false, + }, + }, + } + + bobDev.AccountType = accType + userType := "guest" + if accType == userapi.AccountTypeUser { + userType = "real user" + } + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use the actual internal roomserver API + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{}) + + for _, tc := range testCases { + testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) + t.Run(testname, func(t *testing.T) { + // create a room with the given visibility + room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility)) + + // send the events/messages to NATS to create the rooms + beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)}) + eventsToSend := append(room.Events(), beforeJoinEv) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // There is only one event, we expect only to be able to see this, if the room is world_readable + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + "dir": "b", + }))) + if w.Code != 200 { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + // We only care about the returned events at this point + var res struct { + Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + } + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Errorf("failed to decode response body: %s", err) + } + + verifyEventVisible(t, tc.wantResult.seeWithoutJoin, beforeJoinEv, res.Chunk) + + // Create invite, a message, join the room and create another message. + inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID)) + afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)}) + joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID)) + msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}) + + eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Verify the messages after/before invite are visible or not + w = httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + "dir": "b", + }))) + if w.Code != 200 { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Errorf("failed to decode response body: %s", err) + } + // verify results + verifyEventVisible(t, tc.wantResult.seeBeforeJoin, beforeJoinEv, res.Chunk) + verifyEventVisible(t, tc.wantResult.seeAfterInvite, afterInviteEv, res.Chunk) + }) + } + } +} + +func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) { + t.Helper() + if wantVisible { + for _, ev := range chunk { + if ev.EventID == wantVisibleEvent.EventID() { + return + } + } + t.Fatalf("expected to see event %s but didn't: %+v", wantVisibleEvent.EventID(), chunk) + } else { + for _, ev := range chunk { + if ev.EventID == wantVisibleEvent.EventID() { + t.Fatalf("expected not to see event %s: %+v", wantVisibleEvent.EventID(), string(ev.Content)) + } + } + } +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } @@ -448,7 +627,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } -func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { +func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { var addsStateIDs []string @@ -460,6 +639,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, AddsStateEventIDs: addsStateIDs, + HistoryVisibility: ev.Visibility, }, }) } diff --git a/sytest-whitelist b/sytest-whitelist index 2a145291..88dfe920 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -110,8 +110,6 @@ Newly joined room is included in an incremental sync User is offline if they set_presence=offline in their sync Changes to state are included in an incremental sync A change to displayname should appear in incremental /sync -Current state appears in timeline in private history -Current state appears in timeline in private history with many messages before Rooms a user is invited to appear in an initial sync Rooms a user is invited to appear in an incremental sync Sync can be polled for updates @@ -458,7 +456,6 @@ After changing password, a different session no longer works by default Read markers appear in incremental v2 /sync Read markers appear in initial v2 /sync Read markers can be updated -Local users can peek into world_readable rooms by room ID We can't peek into rooms with shared history_visibility We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility @@ -720,4 +717,26 @@ Setting state twice is idempotent Joining room twice is idempotent Inbound federation can return missing events for shared visibility Inbound federation ignores redactions from invalid servers room > v3 +Joining room twice is idempotent +Getting messages going forward is limited for a departed room (SPEC-216) +m.room.history_visibility == "shared" allows/forbids appropriately for Guest users +m.room.history_visibility == "invited" allows/forbids appropriately for Guest users +m.room.history_visibility == "default" allows/forbids appropriately for Guest users +m.room.history_visibility == "shared" allows/forbids appropriately for Real users +m.room.history_visibility == "invited" allows/forbids appropriately for Real users +m.room.history_visibility == "default" allows/forbids appropriately for Real users +Guest users can sync from world_readable guest_access rooms if joined +Guest users can sync from shared guest_access rooms if joined +Guest users can sync from invited guest_access rooms if joined +Guest users can sync from joined guest_access rooms if joined +Guest users can sync from default guest_access rooms if joined +Real users can sync from world_readable guest_access rooms if joined +Real users can sync from shared guest_access rooms if joined +Real users can sync from invited guest_access rooms if joined +Real users can sync from joined guest_access rooms if joined +Real users can sync from default guest_access rooms if joined +Only see history_visibility changes on boundaries +Current state appears in timeline in private history +Current state appears in timeline in private history with many messages before +Local users can peek into world_readable rooms by room ID Newly joined room includes presence in incremental sync \ No newline at end of file diff --git a/test/room.go b/test/room.go index 6ae403b3..94eb51bb 100644 --- a/test/room.go +++ b/test/room.go @@ -37,10 +37,11 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -61,6 +62,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { preset: PresetPublicChat, Version: gomatrixserverlib.RoomVersionV9, currentState: make(map[string]*gomatrixserverlib.HeaderedEvent), + visibility: gomatrixserverlib.HistoryVisibilityShared, } for _, m := range modifiers { m(t, r) @@ -97,10 +99,14 @@ func (r *Room) insertCreateEvents(t *testing.T) { fallthrough case PresetPrivateChat: joinRule.JoinRule = "invite" - hisVis.HistoryVisibility = "shared" + hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared case PresetPublicChat: joinRule.JoinRule = "public" - hisVis.HistoryVisibility = "shared" + hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared + } + + if r.visibility != "" { + hisVis.HistoryVisibility = r.visibility } r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ @@ -183,7 +189,9 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } - return ev.Headered(r.Version) + headeredEvent := ev.Headered(r.Version) + headeredEvent.Visibility = r.visibility + return headeredEvent } // Add a new event to this room DAG. Not thread-safe. @@ -242,6 +250,12 @@ func RoomPreset(p Preset) roomModifier { } } +func RoomHistoryVisibility(vis gomatrixserverlib.HistoryVisibility) roomModifier { + return func(t *testing.T, r *Room) { + r.visibility = vis + } +} + func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { return func(t *testing.T, r *Room) { r.Version = ver diff --git a/test/user.go b/test/user.go index 0020098a..692eae35 100644 --- a/test/user.go +++ b/test/user.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "testing" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -45,7 +46,8 @@ var ( ) type User struct { - ID string + ID string + accountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey @@ -62,6 +64,12 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve } } +func WithAccountType(accountType api.AccountType) UserOpt { + return func(u *User) { + u.accountType = accountType + } +} + func NewUser(t *testing.T, opts ...UserOpt) *User { counter := atomic.AddInt64(&userIDCounter, 1) var u User