From 29841bed6b1a88787211368e6052a87a658c5714 Mon Sep 17 00:00:00 2001 From: Alex Chen Date: Fri, 12 Jul 2019 22:59:53 +0800 Subject: [PATCH] Add typing notifications to /sync responses - fixes #635 (#718) This PR adds a new consumer for typing notifications in syncapi. It also brings changes to syncserver.go and some related files so EDUs can better fit in /sync responses. Fixes #635. Fixes #574. --- syncapi/consumers/clientapi.go | 9 +- syncapi/consumers/roomserver.go | 12 +- syncapi/consumers/typingserver.go | 96 ++++++ syncapi/routing/routing.go | 2 +- syncapi/routing/state.go | 4 +- syncapi/storage/account_data_table.go | 4 +- syncapi/storage/output_room_events_table.go | 11 +- syncapi/storage/syncserver.go | 342 ++++++++++++++------ syncapi/sync/notifier.go | 68 ++-- syncapi/sync/notifier_test.go | 130 +++++--- syncapi/sync/request.go | 44 ++- syncapi/sync/requestpool.go | 20 +- syncapi/sync/userstream.go | 18 +- syncapi/syncapi.go | 16 +- syncapi/types/types.go | 43 ++- testfile | 3 + typingserver/api/output.go | 7 +- typingserver/cache/cache.go | 123 +++++-- typingserver/input/input.go | 12 +- 19 files changed, 712 insertions(+), 252 deletions(-) create mode 100644 syncapi/consumers/typingserver.go diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index d05a7692..f0db5642 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" sarama "gopkg.in/Shopify/sarama.v1" ) @@ -29,7 +30,7 @@ import ( // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { clientAPIConsumer *common.ContinualConsumer - db *storage.SyncServerDatabase + db *storage.SyncServerDatasource notifier *sync.Notifier } @@ -38,7 +39,7 @@ func NewOutputClientDataConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatabase, + store *storage.SyncServerDatasource, ) *OutputClientDataConsumer { consumer := common.ContinualConsumer{ @@ -78,7 +79,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error "room_id": output.RoomID, }).Info("received data from client API server") - syncStreamPos, err := s.db.UpsertAccountData( + pduPos, err := s.db.UpsertAccountData( context.TODO(), string(msg.Key), output.RoomID, output.Type, ) if err != nil { @@ -89,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, string(msg.Key), syncStreamPos) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos}) return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1866a966..e4f1ab46 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -33,7 +33,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *common.ContinualConsumer - db *storage.SyncServerDatabase + db *storage.SyncServerDatasource notifier *sync.Notifier query api.RoomserverQueryAPI } @@ -43,7 +43,7 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatabase, + store *storage.SyncServerDatasource, queryAPI api.RoomserverQueryAPI, ) *OutputRoomEventConsumer { @@ -126,7 +126,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } } - syncStreamPos, err := s.db.WriteEvent( + pduPos, err := s.db.WriteEvent( ctx, &ev, addsStateEvents, @@ -144,7 +144,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", types.StreamPosition(syncStreamPos)) + s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos}) return nil } @@ -152,7 +152,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent( ctx context.Context, msg api.OutputNewInviteEvent, ) error { - syncStreamPos, err := s.db.AddInviteEvent(ctx, msg.Event) + pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ @@ -161,7 +161,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", syncStreamPos) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos}) return nil } diff --git a/syncapi/consumers/typingserver.go b/syncapi/consumers/typingserver.go new file mode 100644 index 00000000..5d998a18 --- /dev/null +++ b/syncapi/consumers/typingserver.go @@ -0,0 +1,96 @@ +// Copyright 2019 Alex Chen +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/typingserver/api" + log "github.com/sirupsen/logrus" + sarama "gopkg.in/Shopify/sarama.v1" +) + +// OutputTypingEventConsumer consumes events that originated in the typing server. +type OutputTypingEventConsumer struct { + typingConsumer *common.ContinualConsumer + db *storage.SyncServerDatasource + notifier *sync.Notifier +} + +// NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer. +// Call Start() to begin consuming from the typing server. +func NewOutputTypingEventConsumer( + cfg *config.Dendrite, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store *storage.SyncServerDatasource, +) *OutputTypingEventConsumer { + + consumer := common.ContinualConsumer{ + Topic: string(cfg.Kafka.Topics.OutputTypingEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputTypingEventConsumer{ + typingConsumer: &consumer, + db: store, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from typing api +func (s *OutputTypingEventConsumer) Start() error { + s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { + s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition}) + }) + + return s.typingConsumer.Start() +} + +func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputTypingEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("typing server output log: message parse failure") + return nil + } + + log.WithFields(log.Fields{ + "room_id": output.Event.RoomID, + "user_id": output.Event.UserID, + "typing": output.Event.Typing, + }).Debug("received data from typing server") + + var typingPos int64 + typingEvent := output.Event + if typingEvent.Typing { + typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime) + } else { + typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) + } + + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos}) + return nil +} diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index cbdcfb6b..0f5019fc 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0" // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo -func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) { +func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatasource, deviceDB *devices.Database) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ diff --git a/syncapi/routing/state.go b/syncapi/routing/state.go index 6b98a0b7..5571a052 100644 --- a/syncapi/routing/state.go +++ b/syncapi/routing/state.go @@ -40,7 +40,7 @@ type stateEventInStateResp struct { // TODO: Check if the user is in the room. If not, check if the room's history // is publicly visible. Current behaviour is returning an empty array if the // user cannot see the room's history. -func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string) util.JSONResponse { +func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) @@ -84,7 +84,7 @@ func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, r // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // state to see if there is an event with that type and state key, if there // is then (by default) we return the content, otherwise a 404. -func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string, evType, stateKey string) util.JSONResponse { +func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string, evType, stateKey string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) diff --git a/syncapi/storage/account_data_table.go b/syncapi/storage/account_data_table.go index d4d74d15..9b73ce7d 100644 --- a/syncapi/storage/account_data_table.go +++ b/syncapi/storage/account_data_table.go @@ -19,8 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/common" - - "github.com/matrix-org/dendrite/syncapi/types" ) const accountDataSchema = ` @@ -94,7 +92,7 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountDataInRange( ctx context.Context, userID string, - oldPos, newPos types.StreamPosition, + oldPos, newPos int64, ) (data map[string][]string, err error) { data = make(map[string][]string) diff --git a/syncapi/storage/output_room_events_table.go b/syncapi/storage/output_room_events_table.go index 035db988..06df017c 100644 --- a/syncapi/storage/output_room_events_table.go +++ b/syncapi/storage/output_room_events_table.go @@ -23,7 +23,6 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -109,11 +108,11 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { return } -// selectStateInRange returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. +// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( - ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, + ctx context.Context, txn *sql.Tx, oldPos, newPos int64, ) (map[string]map[string]bool, map[string]streamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) @@ -171,7 +170,7 @@ func (s *outputRoomEventsStatements) selectStateInRange( eventIDToEvent[ev.EventID()] = streamEvent{ Event: ev, - streamPosition: types.StreamPosition(streamPos), + streamPosition: streamPos, } } @@ -223,7 +222,7 @@ func (s *outputRoomEventsStatements) insertEvent( // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. func (s *outputRoomEventsStatements) selectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos types.StreamPosition, limit int, + roomID string, fromPos, toPos int64, limit int, ) ([]streamEvent, error) { stmt := common.TxStmt(txn, s.selectRecentEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) @@ -286,7 +285,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { result = append(result, streamEvent{ Event: ev, - streamPosition: types.StreamPosition(streamPos), + streamPosition: streamPos, transactionID: transactionID, }) } diff --git a/syncapi/storage/syncserver.go b/syncapi/storage/syncserver.go index b0655a0a..b4d7ccbd 100644 --- a/syncapi/storage/syncserver.go +++ b/syncapi/storage/syncserver.go @@ -17,7 +17,10 @@ package storage import ( "context" "database/sql" + "encoding/json" "fmt" + "strconv" + "time" "github.com/sirupsen/logrus" @@ -28,6 +31,7 @@ import ( _ "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/typingserver/cache" "github.com/matrix-org/gomatrixserverlib" ) @@ -35,33 +39,35 @@ type stateDelta struct { roomID string stateEvents []gomatrixserverlib.Event membership string - // The stream position of the latest membership event for this user, if applicable. + // The PDU stream position of the latest membership event for this user, if applicable. // Can be 0 if there is no membership event in this delta. - membershipPos types.StreamPosition + membershipPos int64 } -// Same as gomatrixserverlib.Event but also has the stream position for this event. +// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. type streamEvent struct { gomatrixserverlib.Event - streamPosition types.StreamPosition + streamPosition int64 transactionID *api.TransactionID } -// SyncServerDatabase represents a sync server database -type SyncServerDatabase struct { +// SyncServerDatabase represents a sync server datasource which manages +// both the database for PDUs and caches for EDUs. +type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements invites inviteEventsStatements + typingCache *cache.TypingCache } // NewSyncServerDatabase creates a new sync server database -func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { - var d SyncServerDatabase +func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { + var d SyncServerDatasource var err error - if d.db, err = sql.Open("postgres", dataSourceName); err != nil { + if d.db, err = sql.Open("postgres", dbDataSourceName); err != nil { return nil, err } if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { @@ -79,11 +85,12 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { if err := d.invites.prepare(d.db); err != nil { return nil, err } + d.typingCache = cache.NewTypingCache() return &d, nil } // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. -func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { +func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { return d.roomstate.selectJoinedUsers(ctx) } @@ -92,7 +99,7 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. -func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { +func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) if err != nil { return nil, err @@ -104,38 +111,38 @@ func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]g } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races -// when generating the stream position for this event. Returns the sync stream position for the inserted event. +// when generating the sync stream position for this event. Returns the sync stream position for the inserted event. // Returns an error if there was a problem inserting this event. -func (d *SyncServerDatabase) WriteEvent( +func (d *SyncServerDatasource) WriteEvent( ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, -) (streamPos types.StreamPosition, returnErr error) { +) (pduPosition int64, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) if err != nil { return err } - streamPos = types.StreamPosition(pos) + pduPosition = pos if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. return nil } - return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, streamPos) + return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) }) return } -func (d *SyncServerDatabase) updateRoomState( +func (d *SyncServerDatasource) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event, - streamPos types.StreamPosition, + pduPosition int64, ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { @@ -157,7 +164,7 @@ func (d *SyncServerDatabase) updateRoomState( } membership = &value } - if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, int64(streamPos)); err != nil { + if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { return err } } @@ -168,7 +175,7 @@ func (d *SyncServerDatabase) updateRoomState( // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error -func (d *SyncServerDatabase) GetStateEvent( +func (d *SyncServerDatasource) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.Event, error) { return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) @@ -177,7 +184,7 @@ func (d *SyncServerDatabase) GetStateEvent( // GetStateEventsForRoom fetches the state events for a given room. // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. -func (d *SyncServerDatabase) GetStateEventsForRoom( +func (d *SyncServerDatasource) GetStateEventsForRoom( ctx context.Context, roomID string, ) (stateEvents []gomatrixserverlib.Event, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { @@ -187,46 +194,49 @@ func (d *SyncServerDatabase) GetStateEventsForRoom( return } -// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatabase) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { - return d.syncStreamPositionTx(ctx, nil) +// SyncPosition returns the latest positions for syncing. +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { + return d.syncPositionTx(ctx, nil) } -func (d *SyncServerDatabase) syncStreamPositionTx( +func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (types.StreamPosition, error) { - maxID, err := d.events.selectMaxEventID(ctx, txn) +) (sp types.SyncPosition, err error) { + + maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { - return 0, err + return sp, err } maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) if err != nil { - return 0, err + return sp, err } - if maxAccountDataID > maxID { - maxID = maxAccountDataID + if maxAccountDataID > maxEventID { + maxEventID = maxAccountDataID } maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) if err != nil { - return 0, err + return sp, err } - if maxInviteID > maxID { - maxID = maxInviteID + if maxInviteID > maxEventID { + maxEventID = maxInviteID } - return types.StreamPosition(maxID), nil + sp.PDUPosition = maxEventID + + sp.TypingPosition = d.typingCache.GetLatestSyncPosition() + + return } -// IncrementalSync returns all the data needed in order to create an incremental -// sync response for the given user. Events returned will include any client -// transaction IDs associated with the given device. These transaction IDs come -// from when the device sent the event via an API that included a transaction -// ID. -func (d *SyncServerDatabase) IncrementalSync( +// addPDUDeltaToResponse adds all PDU deltas to a sync response. +// IDs of all rooms the user joined are returned so EDU deltas can be added for them. +func (d *SyncServerDatasource) addPDUDeltaToResponse( ctx context.Context, device authtypes.Device, - fromPos, toPos types.StreamPosition, + fromPos, toPos int64, numRecentEventsPerRoom int, -) (*types.Response, error) { + res *types.Response, +) ([]string, error) { txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) if err != nil { return nil, err @@ -235,7 +245,7 @@ func (d *SyncServerDatabase) IncrementalSync( defer common.EndTransaction(txn, &succeeded) // Work out which rooms to return in the response. This is done by getting not only the currently - // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. + // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. // This works out what the 'state' key should be for each room as well as which membership block // to put the room into. deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID) @@ -243,8 +253,9 @@ func (d *SyncServerDatabase) IncrementalSync( return nil, err } - res := types.NewResponse(toPos) + joinedRoomIDs := make([]string, 0, len(deltas)) for _, delta := range deltas { + joinedRoomIDs = append(joinedRoomIDs, delta.roomID) err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) if err != nil { return nil, err @@ -257,52 +268,151 @@ func (d *SyncServerDatabase) IncrementalSync( } succeeded = true + return joinedRoomIDs, nil +} + +// addTypingDeltaToResponse adds all typing notifications to a sync response +// since the specified position. +func (d *SyncServerDatasource) addTypingDeltaToResponse( + since int64, + joinedRoomIDs []string, + res *types.Response, +) error { + var jr types.JoinResponse + var ok bool + var err error + for _, roomID := range joinedRoomIDs { + if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter( + roomID, since, + ); updated { + ev := gomatrixserverlib.ClientEvent{ + Type: gomatrixserverlib.MTyping, + } + ev.Content, err = json.Marshal(map[string]interface{}{ + "user_ids": typingUsers, + }) + if err != nil { + return err + } + + if jr, ok = res.Rooms.Join[roomID]; !ok { + jr = *types.NewJoinResponse() + } + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + res.Rooms.Join[roomID] = jr + } + } + return nil +} + +// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if +// the positions of that type are not equal in fromPos and toPos. +func (d *SyncServerDatasource) addEDUDeltaToResponse( + fromPos, toPos types.SyncPosition, + joinedRoomIDs []string, + res *types.Response, +) (err error) { + + if fromPos.TypingPosition != toPos.TypingPosition { + err = d.addTypingDeltaToResponse( + fromPos.TypingPosition, joinedRoomIDs, res, + ) + } + + return +} + +// IncrementalSync returns all the data needed in order to create an incremental +// sync response for the given user. Events returned will include any client +// transaction IDs associated with the given device. These transaction IDs come +// from when the device sent the event via an API that included a transaction +// ID. +func (d *SyncServerDatasource) IncrementalSync( + ctx context.Context, + device authtypes.Device, + fromPos, toPos types.SyncPosition, + numRecentEventsPerRoom int, +) (*types.Response, error) { + nextBatchPos := fromPos.WithUpdates(toPos) + res := types.NewResponse(nextBatchPos) + + var joinedRoomIDs []string + var err error + if fromPos.PDUPosition != toPos.PDUPosition { + joinedRoomIDs, err = d.addPDUDeltaToResponse( + ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, res, + ) + } else { + joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( + ctx, nil, device.UserID, "join", + ) + } + if err != nil { + return nil, err + } + + err = d.addEDUDeltaToResponse( + fromPos, toPos, joinedRoomIDs, res, + ) + if err != nil { + return nil, err + } + return res, nil } -// CompleteSync a complete /sync API response for the given user. -func (d *SyncServerDatabase) CompleteSync( - ctx context.Context, userID string, numRecentEventsPerRoom int, -) (*types.Response, error) { +// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed +// to it. It returns toPos and joinedRoomIDs for use of adding EDUs. +func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( + ctx context.Context, + userID string, + numRecentEventsPerRoom int, +) ( + res *types.Response, + toPos types.SyncPosition, + joinedRoomIDs []string, + err error, +) { // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have - // a consistent view of the database throughout. This includes extracting the sync stream position. + // a consistent view of the database throughout. This includes extracting the sync position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) if err != nil { - return nil, err + return } var succeeded bool defer common.EndTransaction(txn, &succeeded) - // Get the current stream position which we will base the sync response on. - pos, err := d.syncStreamPositionTx(ctx, txn) + // Get the current sync position which we will base the sync response on. + toPos, err = d.syncPositionTx(ctx, txn) if err != nil { - return nil, err + return } + res = types.NewResponse(toPos) + // Extract room state and recent events for all rooms the user is joined to. - roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") + joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") if err != nil { - return nil, err + return } // Build up a /sync response. Add joined rooms. - res := types.NewResponse(pos) - for _, roomID := range roomIDs { + for _, roomID := range joinedRoomIDs { var stateEvents []gomatrixserverlib.Event stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) if err != nil { - return nil, err + return } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []streamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, + ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, ) if err != nil { - return nil, err + return } // We don't include a device here as we don't need to send down @@ -311,10 +421,12 @@ func (d *SyncServerDatabase) CompleteSync( stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { - jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() + if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) } else { - jr.Timeline.PrevBatch = types.StreamPosition(1).String() + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = "1" } jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true @@ -322,12 +434,34 @@ func (d *SyncServerDatabase) CompleteSync( res.Rooms.Join[roomID] = *jr } - if err = d.addInvitesToResponse(ctx, txn, userID, 0, pos, res); err != nil { - return nil, err + if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { + return } succeeded = true - return res, err + return res, toPos, joinedRoomIDs, err +} + +// CompleteSync returns a complete /sync API response for the given user. +func (d *SyncServerDatasource) CompleteSync( + ctx context.Context, userID string, numRecentEventsPerRoom int, +) (*types.Response, error) { + res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( + ctx, userID, numRecentEventsPerRoom, + ) + if err != nil { + return nil, err + } + + // Use a zero value SyncPosition for fromPos so all EDU states are added. + err = d.addEDUDeltaToResponse( + types.SyncPosition{}, toPos, joinedRoomIDs, res, + ) + if err != nil { + return nil, err + } + + return res, nil } var txReadOnlySnapshot = sql.TxOptions{ @@ -345,8 +479,8 @@ var txReadOnlySnapshot = sql.TxOptions{ // Returns a map following the format data[roomID] = []dataTypes // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error -func (d *SyncServerDatabase) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, +func (d *SyncServerDatasource) GetAccountDataInRange( + ctx context.Context, userID string, oldPos, newPos int64, ) (map[string][]string, error) { return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos) } @@ -357,26 +491,24 @@ func (d *SyncServerDatabase) GetAccountDataInRange( // If no data with the given type, user ID and room ID exists in the database, // creates a new row, else update the existing one // Returns an error if there was an issue with the upsert -func (d *SyncServerDatabase) UpsertAccountData( +func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, -) (types.StreamPosition, error) { - pos, err := d.accountData.insertAccountData(ctx, userID, roomID, dataType) - return types.StreamPosition(pos), err +) (int64, error) { + return d.accountData.insertAccountData(ctx, userID, roomID, dataType) } // AddInviteEvent stores a new invite event for a user. // If the invite was successfully stored this returns the stream ID it was stored at. // Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatabase) AddInviteEvent( +func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (types.StreamPosition, error) { - pos, err := d.invites.insertInviteEvent(ctx, inviteEvent) - return types.StreamPosition(pos), err +) (int64, error) { + return d.invites.insertInviteEvent(ctx, inviteEvent) } // RetireInviteEvent removes an old invite event from the database. // Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatabase) RetireInviteEvent( +func (d *SyncServerDatasource) RetireInviteEvent( ctx context.Context, inviteEventID string, ) error { // TODO: Record that invite has been retired in a stream so that we can @@ -385,10 +517,30 @@ func (d *SyncServerDatabase) RetireInviteEvent( return err } -func (d *SyncServerDatabase) addInvitesToResponse( +func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { + d.typingCache.SetTimeoutCallback(fn) +} + +// AddTypingUser adds a typing user to the typing cache. +// Returns the newly calculated sync position for typing notifications. +func (d *SyncServerDatasource) AddTypingUser( + userID, roomID string, expireTime *time.Time, +) int64 { + return d.typingCache.AddTypingUser(userID, roomID, expireTime) +} + +// RemoveTypingUser removes a typing user from the typing cache. +// Returns the newly calculated sync position for typing notifications. +func (d *SyncServerDatasource) RemoveTypingUser( + userID, roomID string, +) int64 { + return d.typingCache.RemoveUser(userID, roomID) +} + +func (d *SyncServerDatasource) addInvitesToResponse( ctx context.Context, txn *sql.Tx, userID string, - fromPos, toPos types.StreamPosition, + fromPos, toPos int64, res *types.Response, ) error { invites, err := d.invites.selectInviteEventsInRange( @@ -409,11 +561,11 @@ func (d *SyncServerDatabase) addInvitesToResponse( } // addRoomDeltaToResponse adds a room state delta to a sync response -func (d *SyncServerDatabase) addRoomDeltaToResponse( +func (d *SyncServerDatasource) addRoomDeltaToResponse( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, + fromPos, toPos int64, delta stateDelta, numRecentEventsPerRoom int, res *types.Response, @@ -445,10 +597,12 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( switch delta.membership { case "join": jr := types.NewJoinResponse() - if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { - jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() + if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) } else { - jr.Timeline.PrevBatch = types.StreamPosition(1).String() + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = "1" } jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true @@ -460,10 +614,12 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { - lr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() + if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { + // Use the short form of batch token for prev_batch + lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) } else { - lr.Timeline.PrevBatch = types.StreamPosition(1).String() + // Use the short form of batch token for prev_batch + lr.Timeline.PrevBatch = "1" } lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true @@ -476,7 +632,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // Returns a map of room ID to list of events. -func (d *SyncServerDatabase) fetchStateEvents( +func (d *SyncServerDatasource) fetchStateEvents( ctx context.Context, txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent, @@ -521,7 +677,7 @@ func (d *SyncServerDatabase) fetchStateEvents( return stateBetween, nil } -func (d *SyncServerDatabase) fetchMissingStateEvents( +func (d *SyncServerDatasource) fetchMissingStateEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]streamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the @@ -560,9 +716,9 @@ func (d *SyncServerDatabase) fetchMissingStateEvents( return events, nil } -func (d *SyncServerDatabase) getStateDeltas( +func (d *SyncServerDatasource) getStateDeltas( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, userID string, + fromPos, toPos int64, userID string, ) ([]stateDelta, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -601,7 +757,7 @@ func (d *SyncServerDatabase) getStateDeltas( } s := make([]streamEvent, len(allState)) for i := 0; i < len(s); i++ { - s[i] = streamEvent{Event: allState[i], streamPosition: types.StreamPosition(0)} + s[i] = streamEvent{Event: allState[i], streamPosition: 0} } state[roomID] = s continue // we'll add this room in when we do joined rooms diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 5ed701d8..30ac3a2e 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -26,7 +26,7 @@ import ( ) // Notifier will wake up sleeping requests when there is some new data. -// It does not tell requests what that data is, only the stream position which +// It does not tell requests what that data is, only the sync position which // they can use to get at it. This is done to prevent races whereby we tell the caller // the event, but the token has already advanced by the time they fetch it, resulting // in missed events. @@ -35,18 +35,18 @@ type Notifier struct { roomIDToJoinedUsers map[string]userIDSet // Protects currPos and userStreams. streamLock *sync.Mutex - // The latest sync stream position - currPos types.StreamPosition + // The latest sync position + currPos types.SyncPosition // A map of user_id => UserStream which can be used to wake a given user's /sync request. userStreams map[string]*UserStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time } -// NewNotifier creates a new notifier set to the given stream position. +// NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(pos types.StreamPosition) *Notifier { +func NewNotifier(pos types.SyncPosition) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), @@ -58,20 +58,30 @@ func NewNotifier(pos types.StreamPosition) *Notifier { // OnNewEvent is called when a new event is received from the room server. Must only be // called from a single goroutine, to avoid races between updates which could set the -// current position in the stream incorrectly. -// Can be called either with a *gomatrixserverlib.Event, or with an user ID -func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos types.StreamPosition) { +// current sync position incorrectly. +// Chooses which user sync streams to update by a provided *gomatrixserverlib.Event +// (based on the users in the event's room), +// a roomID directly, or a list of user IDs, prioritised by parameter ordering. +// posUpdate contains the latest position(s) for one or more types of events. +// If a position in posUpdate is 0, it means no updates are available of that type. +// Typically a consumer supplies a posUpdate with the latest sync position for the +// event type it handles, leaving other fields as 0. +func (n *Notifier) OnNewEvent( + ev *gomatrixserverlib.Event, roomID string, userIDs []string, + posUpdate types.SyncPosition, +) { // update the current position then notify relevant /sync streams. // This needs to be done PRIOR to waking up users as they will read this value. n.streamLock.Lock() defer n.streamLock.Unlock() - n.currPos = pos + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos n.removeEmptyUserStreams() if ev != nil { // Map this event's room_id to a list of joined users, and wake them up. - userIDs := n.joinedUsers(ev.RoomID()) + usersToNotify := n.joinedUsers(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { targetUserID := *ev.StateKey() @@ -84,11 +94,11 @@ func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos ty // Keep the joined user map up-to-date switch membership { case "invite": - userIDs = append(userIDs, targetUserID) + usersToNotify = append(usersToNotify, targetUserID) case "join": // Manually append the new user's ID so they get notified // along all members in the room - userIDs = append(userIDs, targetUserID) + usersToNotify = append(usersToNotify, targetUserID) n.addJoinedUser(ev.RoomID(), targetUserID) case "leave": fallthrough @@ -98,11 +108,15 @@ func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos ty } } - for _, toNotifyUserID := range userIDs { - n.wakeupUser(toNotifyUserID, pos) - } - } else if len(userID) > 0 { - n.wakeupUser(userID, pos) + n.wakeupUsers(usersToNotify, latestPos) + } else if roomID != "" { + n.wakeupUsers(n.joinedUsers(roomID), latestPos) + } else if len(userIDs) > 0 { + n.wakeupUsers(userIDs, latestPos) + } else { + log.WithFields(log.Fields{ + "posUpdate": posUpdate.String, + }).Warn("Notifier.OnNewEvent called but caller supplied no user to wake up") } } @@ -127,7 +141,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { } // Load the membership states required to notify users correctly. -func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) error { +func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatasource) error { roomToUsers, err := db.AllJoinedUsersInRooms(ctx) if err != nil { return err @@ -136,8 +150,11 @@ func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) err return nil } -// CurrentPosition returns the current stream position -func (n *Notifier) CurrentPosition() types.StreamPosition { +// CurrentPosition returns the current sync position +func (n *Notifier) CurrentPosition() types.SyncPosition { + n.streamLock.Lock() + defer n.streamLock.Unlock() + return n.currPos } @@ -156,12 +173,13 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } -func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) { - stream := n.fetchUserStream(userID, false) - if stream == nil { - return +func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) { + for _, userID := range userIDs { + stream := n.fetchUserStream(userID, false) + if stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } } - stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream } // fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 4fa54393..904315e9 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -32,19 +32,40 @@ var ( randomMessageEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event + syncPositionVeryOld types.SyncPosition + syncPositionBefore types.SyncPosition + syncPositionAfter types.SyncPosition + syncPositionNewEDU types.SyncPosition + syncPositionAfter2 types.SyncPosition ) var ( - streamPositionVeryOld = types.StreamPosition(5) - streamPositionBefore = types.StreamPosition(11) - streamPositionAfter = types.StreamPosition(12) - streamPositionAfter2 = types.StreamPosition(13) - roomID = "!test:localhost" - alice = "@alice:localhost" - bob = "@bob:localhost" + roomID = "!test:localhost" + alice = "@alice:localhost" + bob = "@bob:localhost" ) func init() { + baseSyncPos := types.SyncPosition{ + PDUPosition: 0, + TypingPosition: 0, + } + + syncPositionVeryOld = baseSyncPos + syncPositionVeryOld.PDUPosition = 5 + + syncPositionBefore = baseSyncPos + syncPositionBefore.PDUPosition = 11 + + syncPositionAfter = baseSyncPos + syncPositionAfter.PDUPosition = 12 + + syncPositionNewEDU = syncPositionAfter + syncPositionNewEDU.TypingPosition = 1 + + syncPositionAfter2 = baseSyncPos + syncPositionAfter2.PDUPosition = 13 + var err error randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ "type": "m.room.message", @@ -92,19 +113,19 @@ func init() { // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { - n := NewNotifier(streamPositionBefore) - pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionVeryOld)) + n := NewNotifier(syncPositionBefore) + pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld)) if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } - if pos != streamPositionBefore { - t.Fatalf("TestImmediateNotification want %d, got %d", streamPositionBefore, pos) + if pos != syncPositionBefore { + t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos) } } // Test that new events to a joined room unblocks the request. func TestNewEventAndJoinedToRoom(t *testing.T) { - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -112,12 +133,12 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos) } wg.Done() }() @@ -125,14 +146,14 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { stream := n.fetchUserStream(bob, true) waitForBlocking(stream, 1) - n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) wg.Wait() } // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -140,12 +161,12 @@ func TestNewInviteEventForUser(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestNewInviteEventForUser want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos) } wg.Done() }() @@ -153,14 +174,42 @@ func TestNewInviteEventForUser(t *testing.T) { stream := n.fetchUserStream(bob, true) waitForBlocking(stream, 1) - n.OnNewEvent(&aliceInviteBobEvent, "", streamPositionAfter) + n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) + + wg.Wait() +} + +// Test an EDU-only update wakes up the request. +func TestEDUWakeup(t *testing.T) { + n := NewNotifier(syncPositionAfter) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + if err != nil { + t.Errorf("TestNewInviteEventForUser error: %s", err) + } + if pos != syncPositionNewEDU { + t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos) + } + wg.Done() + }() + + stream := n.fetchUserStream(bob, true) + waitForBlocking(stream, 1) + + n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) wg.Wait() } // Test that all blocked requests get woken up on a new event. func TestMultipleRequestWakeup(t *testing.T) { - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -168,12 +217,12 @@ func TestMultipleRequestWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(3) poll := func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestMultipleRequestWakeup error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestMultipleRequestWakeup want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos) } wg.Done() } @@ -184,7 +233,7 @@ func TestMultipleRequestWakeup(t *testing.T) { stream := n.fetchUserStream(bob, true) waitForBlocking(stream, 3) - n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) wg.Wait() @@ -198,7 +247,7 @@ func TestMultipleRequestWakeup(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // listen as bob. Make bob leave room. Make alice send event to room. // Make sure alice gets woken up only and not bob as well. - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -208,18 +257,18 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // Make bob leave the room leaveWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos) } leaveWG.Done() }() bobStream := n.fetchUserStream(bob, true) waitForBlocking(bobStream, 1) - n.OnNewEvent(&bobLeaveEvent, "", streamPositionAfter) + n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) leaveWG.Wait() // send an event into the room. Make sure alice gets it. Bob should not. @@ -227,19 +276,19 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { aliceStream := n.fetchUserStream(alice, true) aliceWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) } - if pos != streamPositionAfter2 { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter2, pos) + if pos != syncPositionAfter2 { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos) } aliceWG.Done() }() go func() { // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) - _, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionAfter)) + _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) if err == nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") } @@ -248,7 +297,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { waitForBlocking(aliceStream, 1) waitForBlocking(bobStream, 1) - n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter2) + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter2) aliceWG.Wait() // it's possible that at this point alice has been informed and bob is about to be informed, so wait @@ -256,18 +305,17 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { time.Sleep(1 * time.Millisecond) } -// same as Notifier.WaitForEvents but with a timeout. -func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) { +func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) { listener := n.GetListener(req) defer listener.Close() select { case <-time.After(5 * time.Second): - return types.StreamPosition(0), fmt.Errorf( + return types.SyncPosition{}, fmt.Errorf( "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): - p := listener.GetStreamPosition() + p := listener.GetSyncPosition() return p, nil } } @@ -280,7 +328,7 @@ func waitForBlocking(s *UserStream, numBlocking uint) { } } -func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { +func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest { return syncRequest{ device: authtypes.Device{UserID: userID}, timeout: 1 * time.Minute, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 35a15f6f..a5d2f60f 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -16,8 +16,10 @@ package sync import ( "context" + "errors" "net/http" "strconv" + "strings" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -36,7 +38,7 @@ type syncRequest struct { device authtypes.Device limit int timeout time.Duration - since *types.StreamPosition // nil means that no since token was supplied + since *types.SyncPosition // nil means that no since token was supplied wantFullState bool log *log.Entry } @@ -73,15 +75,41 @@ func getTimeout(timeoutMS string) time.Duration { } // getSyncStreamPosition tries to parse a 'since' token taken from the API to a -// stream position. If the string is empty then (nil, nil) is returned. -func getSyncStreamPosition(since string) (*types.StreamPosition, error) { +// types.SyncPosition. If the string is empty then (nil, nil) is returned. +// There are two forms of tokens: The full length form containing all PDU and EDU +// positions separated by "_", and the short form containing only the PDU +// position. Short form can be used for, e.g., `prev_batch` tokens. +func getSyncStreamPosition(since string) (*types.SyncPosition, error) { if since == "" { return nil, nil } - i, err := strconv.Atoi(since) - if err != nil { - return nil, err + + posStrings := strings.Split(since, "_") + if len(posStrings) != 2 && len(posStrings) != 1 { + // A token can either be full length or short (PDU-only). + return nil, errors.New("malformed batch token") + } + + positions := make([]int64, len(posStrings)) + for i, posString := range posStrings { + pos, err := strconv.ParseInt(posString, 10, 64) + if err != nil { + return nil, err + } + positions[i] = pos + } + + if len(positions) == 2 { + // Full length token; construct SyncPosition with every entry in + // `positions`. These entries must have the same order with the fields + // in struct SyncPosition, so we disable the govet check below. + return &types.SyncPosition{ //nolint:govet + positions[0], positions[1], + }, nil + } else { + // Token with PDU position only + return &types.SyncPosition{ + PDUPosition: positions[0], + }, nil } - token := types.StreamPosition(i) - return &token, nil } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 89137eb5..a6ec6bd9 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -31,13 +31,13 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { - db *storage.SyncServerDatabase + db *storage.SyncServerDatasource accountDB *accounts.Database notifier *Notifier } // NewRequestPool makes a new RequestPool -func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.Database) *RequestPool { +func NewRequestPool(db *storage.SyncServerDatasource, n *Notifier, adb *accounts.Database) *RequestPool { return &RequestPool{db, adb, n} } @@ -92,11 +92,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // respond with, so we skip the return an go back to waiting for content to // be sent down or the request timing out. var hasTimedOut bool + sincePos := *syncReq.since for { select { // Wait for notifier to wake us up - case <-userStreamListener.GetNotifyChannel(currPos): - currPos = userStreamListener.GetStreamPosition() + case <-userStreamListener.GetNotifyChannel(sincePos): + currPos = userStreamListener.GetSyncPosition() + sincePos = currPos // Or for timeout to expire case <-timer.C: // We just need to ensure we get out of the select after reaching the @@ -128,24 +130,24 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) } else { - res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit) + res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit) } if err != nil { return } - res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos) + res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition) return } func (rp *RequestPool) appendAccountData( - data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, + data *types.Response, userID string, req syncRequest, currentPos int64, ) (*types.Response, error) { // TODO: Account data doesn't have a sync position of its own, meaning that // account data might be sent multiple time to the client if multiple account @@ -179,7 +181,7 @@ func (rp *RequestPool) appendAccountData( } // Sync is not initial, get all account data since the latest sync - dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, *req.since, currentPos) + dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos) if err != nil { return nil, err } diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index 77d09c20..beb10e48 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -34,8 +34,8 @@ type UserStream struct { lock sync.Mutex // Closed when there is an update. signalChannel chan struct{} - // The last stream position that there may have been an update for the suser - pos types.StreamPosition + // The last sync position that there may have been an update for the user + pos types.SyncPosition // The last time when we had some listeners waiting timeOfLastChannel time.Time // The number of listeners waiting @@ -51,7 +51,7 @@ type UserStreamListener struct { } // NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.StreamPosition) *UserStream { +func NewUserStream(userID string, currPos types.SyncPosition) *UserStream { return &UserStream{ UserID: userID, timeOfLastChannel: time.Now(), @@ -84,8 +84,8 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { return listener } -// Broadcast a new stream position for this user. -func (s *UserStream) Broadcast(pos types.StreamPosition) { +// Broadcast a new sync position for this user. +func (s *UserStream) Broadcast(pos types.SyncPosition) { s.lock.Lock() defer s.lock.Unlock() @@ -118,9 +118,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { return s.timeOfLastChannel } -// GetStreamPosition returns last stream position which the UserStream was +// GetStreamPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetStreamPosition() types.StreamPosition { +func (s *UserStreamListener) GetSyncPosition() types.SyncPosition { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -132,11 +132,11 @@ func (s *UserStreamListener) GetStreamPosition() types.StreamPosition { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamPosition) <-chan struct{} { +func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() - if sincePos < s.userStream.pos { + if s.userStream.pos.IsAfter(sincePos) { // If the listener is behind, i.e. missed a potential update, then we // want them to wake up immediately. We do this by returning a new // closed stream, which returns immediately when selected. diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 2db54c3c..4738feea 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" - "github.com/matrix-org/dendrite/syncapi/types" ) // SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI @@ -39,17 +38,17 @@ func SetupSyncAPIComponent( accountsDB *accounts.Database, queryAPI api.RoomserverQueryAPI, ) { - syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI)) + syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI)) if err != nil { logrus.WithError(err).Panicf("failed to connect to sync db") } - pos, err := syncDB.SyncStreamPosition(context.Background()) + pos, err := syncDB.SyncPosition(context.Background()) if err != nil { - logrus.WithError(err).Panicf("failed to get stream position") + logrus.WithError(err).Panicf("failed to get sync position") } - notifier := sync.NewNotifier(types.StreamPosition(pos)) + notifier := sync.NewNotifier(pos) err = notifier.Load(context.Background(), syncDB) if err != nil { logrus.WithError(err).Panicf("failed to start notifier") @@ -71,5 +70,12 @@ func SetupSyncAPIComponent( logrus.WithError(err).Panicf("failed to start client data consumer") } + typingConsumer := consumers.NewOutputTypingEventConsumer( + base.Cfg, base.KafkaConsumer, notifier, syncDB, + ) + if err = typingConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start typing server consumer") + } + routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index d0b1c38a..af7ec865 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -21,12 +21,38 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// StreamPosition represents the offset in the sync stream a client is at. -type StreamPosition int64 +// SyncPosition contains the PDU and EDU stream sync positions for a client. +type SyncPosition struct { + // PDUPosition is the stream position for PDUs the client is at. + PDUPosition int64 + // TypingPosition is the client's position for typing notifications. + TypingPosition int64 +} // String implements the Stringer interface. -func (sp StreamPosition) String() string { - return strconv.FormatInt(int64(sp), 10) +func (sp SyncPosition) String() string { + return strconv.FormatInt(sp.PDUPosition, 10) + "_" + + strconv.FormatInt(sp.TypingPosition, 10) +} + +// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition. +func (sp SyncPosition) IsAfter(other SyncPosition) bool { + return sp.PDUPosition > other.PDUPosition || + sp.TypingPosition > other.TypingPosition +} + +// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition. +// If the latter SyncPosition contains a field that is not 0, it is considered an update, +// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called. +func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition { + ret := sp + if other.PDUPosition != 0 { + ret.PDUPosition = other.PDUPosition + } + if other.TypingPosition != 0 { + ret.TypingPosition = other.TypingPosition + } + return ret } // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -53,11 +79,10 @@ type Response struct { } // NewResponse creates an empty response with initialised maps. -func NewResponse(pos StreamPosition) *Response { - res := Response{} - // Make sure we send the next_batch as a string. We don't want to confuse clients by sending this - // as an integer even though (at the moment) it is. - res.NextBatch = pos.String() +func NewResponse(pos SyncPosition) *Response { + res := Response{ + NextBatch: pos.String(), + } // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. res.Rooms.Join = make(map[string]JoinResponse) diff --git a/testfile b/testfile index e869be4f..bdd421e2 100644 --- a/testfile +++ b/testfile @@ -144,3 +144,6 @@ Events come down the correct room local user can join room with version 5 User can invite local user to room with version 5 Inbound federation can receive room-join requests +Typing events appear in initial sync +Typing events appear in incremental sync +Typing events appear in gapped sync diff --git a/typingserver/api/output.go b/typingserver/api/output.go index 813b9b7c..8696acf4 100644 --- a/typingserver/api/output.go +++ b/typingserver/api/output.go @@ -12,14 +12,17 @@ package api +import "time" + // OutputTypingEvent is an entry in typing server output kafka log. // This contains the event with extra fields used to create 'm.typing' event // in clientapi & federation. type OutputTypingEvent struct { // The Event for the typing edu event. Event TypingEvent `json:"event"` - // Users typing in the room when the event was generated. - TypingUsers []string `json:"typing_users"` + // ExpireTime is the interval after which the user should no longer be + // considered typing. Only available if Event.Typing is true. + ExpireTime *time.Time } // TypingEvent represents a matrix edu event of type 'm.typing'. diff --git a/typingserver/cache/cache.go b/typingserver/cache/cache.go index 85d74cd1..8d84f856 100644 --- a/typingserver/cache/cache.go +++ b/typingserver/cache/cache.go @@ -22,25 +22,66 @@ const defaultTypingTimeout = 10 * time.Second // userSet is a map of user IDs to a timer, timer fires at expiry. type userSet map[string]*time.Timer +// TimeoutCallbackFn is a function called right after the removal of a user +// from the typing user list due to timeout. +// latestSyncPosition is the typing sync position after the removal. +type TimeoutCallbackFn func(userID, roomID string, latestSyncPosition int64) + +type roomData struct { + syncPosition int64 + userSet userSet +} + // TypingCache maintains a list of users typing in each room. type TypingCache struct { sync.RWMutex - data map[string]userSet + latestSyncPosition int64 + data map[string]*roomData + timeoutCallback TimeoutCallbackFn +} + +// Create a roomData with its sync position set to the latest sync position. +// Must only be called after locking the cache. +func (t *TypingCache) newRoomData() *roomData { + return &roomData{ + syncPosition: t.latestSyncPosition, + userSet: make(userSet), + } } // NewTypingCache returns a new TypingCache initialised for use. func NewTypingCache() *TypingCache { - return &TypingCache{data: make(map[string]userSet)} + return &TypingCache{data: make(map[string]*roomData)} +} + +// SetTimeoutCallback sets a callback function that is called right after +// a user is removed from the typing user list due to timeout. +func (t *TypingCache) SetTimeoutCallback(fn TimeoutCallbackFn) { + t.timeoutCallback = fn } // GetTypingUsers returns the list of users typing in a room. -func (t *TypingCache) GetTypingUsers(roomID string) (users []string) { +func (t *TypingCache) GetTypingUsers(roomID string) []string { + users, _ := t.GetTypingUsersIfUpdatedAfter(roomID, 0) + // 0 should work above because the first position used will be 1. + return users +} + +// GetTypingUsersIfUpdatedAfter returns all users typing in this room with +// updated == true if the typing sync position of the room is after the given +// position. Otherwise, returns an empty slice with updated == false. +func (t *TypingCache) GetTypingUsersIfUpdatedAfter( + roomID string, position int64, +) (users []string, updated bool) { t.RLock() - usersMap, ok := t.data[roomID] - t.RUnlock() - if ok { - users = make([]string, 0, len(usersMap)) - for userID := range usersMap { + defer t.RUnlock() + + roomData, ok := t.data[roomID] + if ok && roomData.syncPosition > position { + updated = true + userSet := roomData.userSet + users = make([]string, 0, len(userSet)) + for userID := range userSet { users = append(users, userID) } } @@ -51,25 +92,41 @@ func (t *TypingCache) GetTypingUsers(roomID string) (users []string) { // AddTypingUser sets an user as typing in a room. // expire is the time when the user typing should time out. // if expire is nil, defaultTypingTimeout is assumed. -func (t *TypingCache) AddTypingUser(userID, roomID string, expire *time.Time) { +// Returns the latest sync position for typing after update. +func (t *TypingCache) AddTypingUser( + userID, roomID string, expire *time.Time, +) int64 { expireTime := getExpireTime(expire) if until := time.Until(expireTime); until > 0 { - timer := time.AfterFunc(until, t.timeoutCallback(userID, roomID)) - t.addUser(userID, roomID, timer) + timer := time.AfterFunc(until, func() { + latestSyncPosition := t.RemoveUser(userID, roomID) + if t.timeoutCallback != nil { + t.timeoutCallback(userID, roomID, latestSyncPosition) + } + }) + return t.addUser(userID, roomID, timer) } + return t.GetLatestSyncPosition() } // addUser with mutex lock & replace the previous timer. -func (t *TypingCache) addUser(userID, roomID string, expiryTimer *time.Timer) { +// Returns the latest typing sync position after update. +func (t *TypingCache) addUser( + userID, roomID string, expiryTimer *time.Timer, +) int64 { t.Lock() defer t.Unlock() + t.latestSyncPosition++ + if t.data[roomID] == nil { - t.data[roomID] = make(userSet) + t.data[roomID] = t.newRoomData() + } else { + t.data[roomID].syncPosition = t.latestSyncPosition } // Stop the timer to cancel the call to timeoutCallback - if timer, ok := t.data[roomID][userID]; ok { + if timer, ok := t.data[roomID].userSet[userID]; ok { // It may happen that at this stage timer fires but now we have a lock on t. // Hence the execution of timeoutCallback will happen after we unlock. // So we may lose a typing state, though this event is highly unlikely. @@ -78,26 +135,40 @@ func (t *TypingCache) addUser(userID, roomID string, expiryTimer *time.Timer) { timer.Stop() } - t.data[roomID][userID] = expiryTimer -} + t.data[roomID].userSet[userID] = expiryTimer -// Returns a function which is called after timeout happens. -// This removes the user. -func (t *TypingCache) timeoutCallback(userID, roomID string) func() { - return func() { - t.RemoveUser(userID, roomID) - } + return t.latestSyncPosition } // RemoveUser with mutex lock & stop the timer. -func (t *TypingCache) RemoveUser(userID, roomID string) { +// Returns the latest sync position for typing after update. +func (t *TypingCache) RemoveUser(userID, roomID string) int64 { t.Lock() defer t.Unlock() - if timer, ok := t.data[roomID][userID]; ok { - timer.Stop() - delete(t.data[roomID], userID) + roomData, ok := t.data[roomID] + if !ok { + return t.latestSyncPosition } + + timer, ok := roomData.userSet[userID] + if !ok { + return t.latestSyncPosition + } + + timer.Stop() + delete(roomData.userSet, userID) + + t.latestSyncPosition++ + t.data[roomID].syncPosition = t.latestSyncPosition + + return t.latestSyncPosition +} + +func (t *TypingCache) GetLatestSyncPosition() int64 { + t.Lock() + defer t.Unlock() + return t.latestSyncPosition } func getExpireTime(expire *time.Time) time.Time { diff --git a/typingserver/input/input.go b/typingserver/input/input.go index b9968ce4..0e2fbe51 100644 --- a/typingserver/input/input.go +++ b/typingserver/input/input.go @@ -57,15 +57,21 @@ func (t *TypingServerInputAPI) InputTypingEvent( } func (t *TypingServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { - userIDs := t.Cache.GetTypingUsers(ite.RoomID) ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, RoomID: ite.RoomID, UserID: ite.UserID, + Typing: ite.Typing, } ote := &api.OutputTypingEvent{ - Event: *ev, - TypingUsers: userIDs, + Event: *ev, + } + + if ev.Typing { + expireTime := ite.OriginServerTS.Time().Add( + time.Duration(ite.Timeout) * time.Millisecond, + ) + ote.ExpireTime = &expireTime } eventJSON, err := json.Marshal(ote)