From d5a44fd3e87947aa6e3c9683cf5befc1c9738ad3 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 17 May 2017 15:38:24 +0100 Subject: [PATCH] Only wake up /sync requests which the event is for (#101) --- .../cmd/dendrite-sync-api-server/main.go | 3 + .../storage/current_room_state_table.go | 29 ++ .../dendrite/syncapi/storage/syncserver.go | 5 + .../dendrite/syncapi/sync/notifier.go | 189 ++++++++++-- .../dendrite/syncapi/sync/notifier_test.go | 292 ++++++++++++++++++ .../dendrite/syncapi/sync/request.go | 7 +- .../dendrite/syncapi/sync/userstream.go | 79 +++++ 7 files changed, 582 insertions(+), 22 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go create mode 100644 src/github.com/matrix-org/dendrite/syncapi/sync/userstream.go diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go index f67bf0e5..8b1da837 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go @@ -78,6 +78,9 @@ func main() { } n := sync.NewNotifier(types.StreamPosition(pos)) + if err := n.Load(db); err != nil { + log.Panicf("startup: failed to set up notifier: %s", err) + } server, err := consumers.NewServer(cfg, n, db) if err != nil { log.Panicf("startup: failed to create sync server: %s", err) diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go index e8cc6851..b74514c1 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go @@ -61,11 +61,15 @@ const selectRoomIDsWithMembershipSQL = "" + const selectCurrentStateSQL = "" + "SELECT event_json FROM current_room_state WHERE room_id = $1" +const selectJoinedUsersSQL = "" + + "SELECT room_id, state_key FROM current_room_state WHERE type = 'm.room.member' AND membership = 'join'" + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt } func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { @@ -85,9 +89,34 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { return } + if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { + return + } return } +// JoinedMemberLists returns a map of room ID to a list of joined user IDs. +func (s *currentRoomStateStatements) JoinedMemberLists() (map[string][]string, error) { + rows, err := s.selectJoinedUsersStmt.Query() + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[string][]string) + for rows.Next() { + var roomID string + var userID string + if err := rows.Scan(&roomID, &userID); err != nil { + return nil, err + } + users := result[roomID] + users = append(users, userID) + result[roomID] = users + } + return result, nil +} + // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership) diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index 46e2b9f6..fb1a5c16 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -61,6 +61,11 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { return &SyncServerDatabase{db, partitions, events, state}, nil } +// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. +func (d *SyncServerDatabase) AllJoinedUsersInRooms() (map[string][]string, error) { + return d.roomstate.JoinedMemberLists() +} + // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // when generating the stream position for this event. Returns the sync stream position for the inserted event. // Returns an error if there was a problem inserting this event. diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go index cc986579..1ed0cf55 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go @@ -15,27 +15,42 @@ package sync import ( + "encoding/json" "sync" + log "github.com/Sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/events" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) -// Notifier will wake up sleeping requests in the request pool when there -// is some new data. It does not tell requests what that data is, only the -// stream position which they can use to get at it. +// Notifier will wake up sleeping requests when there is some new data. +// It does not tell requests what that data is, only the stream position which +// they can use to get at it. This is done to prevent races whereby we tell the caller +// the event, but the token has already advanced by the time they fetch it, resulting +// in missed events. type Notifier struct { - // The latest sync stream position: guarded by 'cond'. + // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine + roomIDToJoinedUsers map[string]set + // Protects currPos and userStreams. + streamLock *sync.Mutex + // The latest sync stream position: guarded by 'currPosMutex' which is RW to allow + // for concurrent reads on /sync requests currPos types.StreamPosition - // A condition variable to notify all waiting goroutines of a new sync stream position - cond *sync.Cond + // A map of user_id => UserStream which can be used to wake a given user's /sync request. + userStreams map[string]*UserStream } // NewNotifier creates a new notifier set to the given stream position. +// In order for this to be of any use, the Notifier needs to be told all rooms and +// the joined users within each of them by calling Notifier.LoadFromDatabase(). func NewNotifier(pos types.StreamPosition) *Notifier { return &Notifier{ - pos, - sync.NewCond(&sync.Mutex{}), + currPos: pos, + roomIDToJoinedUsers: make(map[string]set), + userStreams: make(map[string]*UserStream), + streamLock: &sync.Mutex{}, } } @@ -43,25 +58,157 @@ func NewNotifier(pos types.StreamPosition) *Notifier { // called from a single goroutine, to avoid races between updates which could set the // current position in the stream incorrectly. func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, pos types.StreamPosition) { - // update the current position in a guard and then notify all /sync streams - n.cond.L.Lock() + // update the current position then notify relevant /sync streams. + // This needs to be done PRIOR to waking up users as they will read this value. + n.streamLock.Lock() + defer n.streamLock.Unlock() n.currPos = pos - n.cond.L.Unlock() - n.cond.Broadcast() // notify ALL waiting goroutines + // Map this event's room_id to a list of joined users, and wake them up. + userIDs := n.joinedUsers(ev.RoomID()) + // If this is an invite, also add in the invitee to this list. + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + userID := *ev.StateKey() + var memberContent events.MemberContent + if err := json.Unmarshal(ev.Content(), &memberContent); err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: Failed to unmarshal member event", + ) + } else { + // Keep the joined user map up-to-date + switch memberContent.Membership { + case "invite": + userIDs = append(userIDs, userID) + case "join": + n.addJoinedUser(ev.RoomID(), userID) + case "leave": + fallthrough + case "ban": + n.removeJoinedUser(ev.RoomID(), userID) + } + } + } + + for _, userID := range userIDs { + n.wakeupUser(userID, pos) + } } // WaitForEvents blocks until there are new events for this request. func (n *Notifier) WaitForEvents(req syncRequest) types.StreamPosition { - // In a guard, check if the /sync request should block, and block it until we get a new position - n.cond.L.Lock() + // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 + // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID + // - Incoming events wake requests for a matching room ID + // - Incoming events wake requests for a matching user ID (needed for invites) + + // TODO: v1 /events 'peeking' has an 'explicit room ID' which is also tracked, + // but given we don't do /events, let's pretend it doesn't exist. + + // In a guard, check if the /sync request should block, and block it until we get woken up + n.streamLock.Lock() currentPos := n.currPos - for req.since == currentPos { - // we need to wait for a new event. - // TODO: This waits for ANY new event, we need to only wait for events which we care about. - n.cond.Wait() // atomically unlocks and blocks goroutine, then re-acquires lock on unblock - currentPos = n.currPos + + // TODO: We increment the stream position for any event, so it's possible that we return immediately + // with a pos which contains no new events for this user. We should probably re-wait for events + // automatically in this case. + if req.since != currentPos { + n.streamLock.Unlock() + return currentPos } - n.cond.L.Unlock() - return currentPos + + // wait to be woken up, and then re-check the stream position + req.log.WithField("user_id", req.userID).Info("Waiting for event") + + // give up the stream lock prior to waiting on the user lock + stream := n.fetchUserStream(req.userID, true) + n.streamLock.Unlock() + return stream.Wait(currentPos) +} + +// Load the membership states required to notify users correctly. +func (n *Notifier) Load(db *storage.SyncServerDatabase) error { + roomToUsers, err := db.AllJoinedUsersInRooms() + if err != nil { + return err + } + n.setUsersJoinedToRooms(roomToUsers) + return nil +} + +// setUsersJoinedToRooms marks the given users as 'joined' to the given rooms, such that new events from +// these rooms will wake the given users /sync requests. This should be called prior to ANY calls to +// OnNewEvent (eg on startup) to prevent racing. +func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { + // This is just the bulk form of addJoinedUser + for roomID, userIDs := range roomIDToUserIDs { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + n.roomIDToJoinedUsers[roomID] = make(set) + } + for _, userID := range userIDs { + n.roomIDToJoinedUsers[roomID].add(userID) + } + } +} + +func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) { + stream := n.fetchUserStream(userID, false) + if stream == nil { + return + } + stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream +} + +// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, +// a stream will be made for this user if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { + stream, ok := n.userStreams[userID] + if !ok { + // TODO: Unbounded growth of streams (1 per user) + stream = NewUserStream(userID) + n.userStreams[userID] = stream + } + return stream +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) addJoinedUser(roomID, userID string) { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + n.roomIDToJoinedUsers[roomID] = make(set) + } + n.roomIDToJoinedUsers[roomID].add(userID) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) removeJoinedUser(roomID, userID string) { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + n.roomIDToJoinedUsers[roomID] = make(set) + } + n.roomIDToJoinedUsers[roomID].remove(userID) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + return + } + return n.roomIDToJoinedUsers[roomID].values() +} + +// A string set, mainly existing for improving clarity of structs in this file. +type set map[string]bool + +func (s set) add(str string) { + s[str] = true +} + +func (s set) remove(str string) { + delete(s, str) +} + +func (s set) values() (vals []string) { + for str := range s { + vals = append(vals, str) + } + return } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go new file mode 100644 index 00000000..784faf57 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go @@ -0,0 +1,292 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +var ( + randomMessageEvent gomatrixserverlib.Event + aliceInviteBobEvent gomatrixserverlib.Event + bobLeaveEvent gomatrixserverlib.Event +) + +var ( + streamPositionVeryOld = types.StreamPosition(5) + streamPositionBefore = types.StreamPosition(11) + streamPositionAfter = types.StreamPosition(12) + streamPositionAfter2 = types.StreamPosition(13) + roomID = "!test:localhost" + alice = "@alice:localhost" + bob = "@bob:localhost" +) + +func init() { + var err error + randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.message", + "content": { + "body": "Hello World", + "msgtype": "m.text" + }, + "sender": "@noone:localhost", + "room_id": "`+roomID+`", + "origin_server_ts": 12345, + "event_id": "$randomMessageEvent:localhost" + }`), false) + if err != nil { + panic(err) + } + aliceInviteBobEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.member", + "state_key": "`+bob+`", + "content": { + "membership": "invite" + }, + "sender": "`+alice+`", + "room_id": "`+roomID+`", + "origin_server_ts": 12345, + "event_id": "$aliceInviteBobEvent:localhost" + }`), false) + if err != nil { + panic(err) + } + bobLeaveEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ + "type": "m.room.member", + "state_key": "`+bob+`", + "content": { + "membership": "leave" + }, + "sender": "`+bob+`", + "room_id": "`+roomID+`", + "origin_server_ts": 12345, + "event_id": "$bobLeaveEvent:localhost" + }`), false) + if err != nil { + panic(err) + } +} + +// Test that the current position is returned if a request is already behind. +func TestImmediateNotification(t *testing.T) { + n := NewNotifier(streamPositionBefore) + pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionVeryOld)) + if err != nil { + t.Fatalf("TestImmediateNotification error: %s", err) + } + if pos != streamPositionBefore { + t.Fatalf("TestImmediateNotification want %d, got %d", streamPositionBefore, pos) + } +} + +// Test that new events to a joined room unblocks the request. +func TestNewEventAndJoinedToRoom(t *testing.T) { + n := NewNotifier(streamPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: []string{alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + if err != nil { + t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) + } + if pos != streamPositionAfter { + t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", streamPositionAfter, pos) + } + wg.Done() + }() + + stream := n.fetchUserStream(bob, true) + waitForBlocking(stream, 1) + + n.OnNewEvent(&randomMessageEvent, streamPositionAfter) + + wg.Wait() +} + +// Test that an invite unblocks the request +func TestNewInviteEventForUser(t *testing.T) { + n := NewNotifier(streamPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: []string{alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + if err != nil { + t.Errorf("TestNewInviteEventForUser error: %s", err) + } + if pos != streamPositionAfter { + t.Errorf("TestNewInviteEventForUser want %d, got %d", streamPositionAfter, pos) + } + wg.Done() + }() + + stream := n.fetchUserStream(bob, true) + waitForBlocking(stream, 1) + + n.OnNewEvent(&aliceInviteBobEvent, streamPositionAfter) + + wg.Wait() +} + +// Test that all blocked requests get woken up on a new event. +func TestMultipleRequestWakeup(t *testing.T) { + n := NewNotifier(streamPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: []string{alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(3) + poll := func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + if err != nil { + t.Errorf("TestMultipleRequestWakeup error: %s", err) + } + if pos != streamPositionAfter { + t.Errorf("TestMultipleRequestWakeup want %d, got %d", streamPositionAfter, pos) + } + wg.Done() + } + go poll() + go poll() + go poll() + + stream := n.fetchUserStream(bob, true) + waitForBlocking(stream, 3) + + n.OnNewEvent(&randomMessageEvent, streamPositionAfter) + + wg.Wait() + + numWaiting := stream.NumWaiting() + if numWaiting != 0 { + t.Errorf("TestMultipleRequestWakeup NumWaiting() want 0, got %d", numWaiting) + } +} + +// Test that you stop getting woken up when you leave a room. +func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { + // listen as bob. Make bob leave room. Make alice send event to room. + // Make sure alice gets woken up only and not bob as well. + n := NewNotifier(streamPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: []string{alice, bob}, + }) + + var leaveWG sync.WaitGroup + + // Make bob leave the room + leaveWG.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + if err != nil { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) + } + if pos != streamPositionAfter { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter, pos) + } + leaveWG.Done() + }() + bobStream := n.fetchUserStream(bob, true) + waitForBlocking(bobStream, 1) + n.OnNewEvent(&bobLeaveEvent, streamPositionAfter) + leaveWG.Wait() + + // send an event into the room. Make sure alice gets it. Bob should not. + var aliceWG sync.WaitGroup + aliceStream := n.fetchUserStream(alice, true) + aliceWG.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionAfter)) + if err != nil { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) + } + if pos != streamPositionAfter2 { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter2, pos) + } + aliceWG.Done() + }() + + go func() { + // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) + _, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionAfter)) + if err == nil { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") + } + }() + + waitForBlocking(aliceStream, 1) + waitForBlocking(bobStream, 1) + + n.OnNewEvent(&randomMessageEvent, streamPositionAfter2) + aliceWG.Wait() + + // it's possible that at this point alice has been informed and bob is about to be informed, so wait + // for a fraction of a second to account for this race + time.Sleep(1 * time.Millisecond) +} + +// same as Notifier.WaitForEvents but with a timeout. +func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) { + done := make(chan types.StreamPosition, 1) + go func() { + newPos := n.WaitForEvents(req) + done <- newPos + close(done) + }() + select { + case <-time.After(5 * time.Second): + return types.StreamPosition(0), fmt.Errorf( + "waitForEvents timed out waiting for %s (pos=%d)", req.userID, req.since, + ) + case p := <-done: + return p, nil + } +} + +// Wait until something is Wait()ing on the user stream. +func waitForBlocking(s *UserStream, numBlocking int) { + for numBlocking != s.NumWaiting() { + // This is horrible but I don't want to add a signalling mechanism JUST for testing. + time.Sleep(1 * time.Microsecond) + } +} + +func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { + return syncRequest{ + userID: userID, + timeout: 1 * time.Minute, + since: since, + wantFullState: false, + limit: defaultTimelineLimit, + log: util.GetLogger(context.TODO()), + } +} diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go index a44f8557..5260a363 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go @@ -15,10 +15,13 @@ package sync import ( - "github.com/matrix-org/dendrite/syncapi/types" "net/http" "strconv" "time" + + log "github.com/Sirupsen/logrus" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/util" ) const defaultSyncTimeout = time.Duration(30) * time.Second @@ -31,6 +34,7 @@ type syncRequest struct { timeout time.Duration since types.StreamPosition wantFullState bool + log *log.Entry } func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { @@ -48,6 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { since: since, wantFullState: wantFullState, limit: defaultTimelineLimit, // TODO: read from filter + log: util.GetLogger(req.Context()), }, nil } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/userstream.go b/src/github.com/matrix-org/dendrite/syncapi/sync/userstream.go new file mode 100644 index 00000000..349b3e27 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/userstream.go @@ -0,0 +1,79 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +import ( + "sync" + + "github.com/matrix-org/dendrite/syncapi/types" +) + +// UserStream represents a communication mechanism between the /sync request goroutine +// and the underlying sync server goroutines. Goroutines can Wait() for a stream position and +// goroutines can Broadcast(streamPosition) to other goroutines. +type UserStream struct { + UserID string + // Because this is a Cond, we can notify all waiting goroutines so this works + // across devices for the same user. Protects pos. + cond *sync.Cond + // The position to broadcast to callers of Wait(). + pos types.StreamPosition + // The number of goroutines blocked on Wait() - used for testing and metrics + numWaiting int +} + +// NewUserStream creates a new user stream +func NewUserStream(userID string) *UserStream { + return &UserStream{ + UserID: userID, + cond: sync.NewCond(&sync.Mutex{}), + } +} + +// Wait blocks until there is a new stream position for this user, which is then returned. +// waitAtPos should be the position the stream thinks it should be waiting at. +func (s *UserStream) Wait(waitAtPos types.StreamPosition) (pos types.StreamPosition) { + s.cond.L.Lock() + // Before we start blocking, we need to make sure that we didn't race with a call + // to Broadcast() between calling Wait() and actually sleeping. We check the last + // broadcast pos to see if it is newer than the pos we are meant to wait at. If it + // is newer, something has Broadcast to this stream more recently so return immediately. + if s.pos > waitAtPos { + pos = s.pos + s.cond.L.Unlock() + return + } + s.numWaiting++ + s.cond.Wait() + pos = s.pos + s.numWaiting-- + s.cond.L.Unlock() + return +} + +// Broadcast a new stream position for this user. +func (s *UserStream) Broadcast(pos types.StreamPosition) { + s.cond.L.Lock() + s.pos = pos + s.cond.L.Unlock() + s.cond.Broadcast() +} + +// NumWaiting returns the number of goroutines waiting for Wait() to return. Used for metrics and testing. +func (s *UserStream) NumWaiting() int { + s.cond.L.Lock() + defer s.cond.L.Unlock() + return s.numWaiting +}