Per-user-per-device sync streams (#1068)

* Per-user-per-device sync streams

* Tweaks

* Tweaks

* Pass full device into CompleteSync

* Set user IDs and device IDs properly in tests

* Add new test, fix TestNewEventAndWasPreviouslyJoinedToRoom

* nolint a function that is not used yet

* Add test for waking up single device

* Hopefully unstick test

* Try to ensure that TestCorrectStreamWakeup doesn't block forever

* Update tests
This commit is contained in:
Neil Alexander 2020-05-28 10:05:04 +01:00 committed by GitHub
parent 57841fc35e
commit 02fe38e1f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 161 additions and 70 deletions

View File

@ -13,4 +13,4 @@ go build ./cmd/...
./scripts/find-lint.sh ./scripts/find-lint.sh
echo "Testing..." echo "Testing..."
go test ./... go test -v ./...

View File

@ -58,7 +58,7 @@ type Database interface {
// ID. // ID.
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
// CompleteSync returns a complete /sync API response for the given user. // CompleteSync returns a complete /sync API response for the given user.
CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions // updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes

View File

@ -666,10 +666,10 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
} }
func (d *Database) CompleteSync( func (d *Database) CompleteSync(
ctx context.Context, userID string, numRecentEventsPerRoom int, ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int,
) (*types.Response, error) { ) (*types.Response, error) {
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, userID, numRecentEventsPerRoom, ctx, device.UserID, numRecentEventsPerRoom,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -181,7 +181,7 @@ func TestSyncResponse(t *testing.T) {
Name: "CompleteSync limited", Name: "CompleteSync limited",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
// limit set to 5 // limit set to 5
return db.CompleteSync(ctx, testUserIDA, 5) return db.CompleteSync(ctx, testUserDeviceA, 5)
}, },
// want the last 5 events // want the last 5 events
WantTimeline: events[len(events)-5:], WantTimeline: events[len(events)-5:],
@ -193,7 +193,7 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "CompleteSync", Name: "CompleteSync",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
return db.CompleteSync(ctx, testUserIDA, len(events)+1) return db.CompleteSync(ctx, testUserDeviceA, len(events)+1)
}, },
WantTimeline: events, WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time) // We want no state at all as that field in /sync is the delta between the token (beginning of time)

View File

@ -37,8 +37,8 @@ type Notifier struct {
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync position // The latest sync position
currPos types.StreamingToken currPos types.StreamingToken
// A map of user_id => UserStream which can be used to wake a given user's /sync request. // A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream userDeviceStreams map[string]map[string]*UserDeviceStream
// The last time we cleaned out stale entries from the userStreams map // The last time we cleaned out stale entries from the userStreams map
lastCleanUpTime time.Time lastCleanUpTime time.Time
} }
@ -50,7 +50,7 @@ func NewNotifier(pos types.StreamingToken) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
userStreams: make(map[string]*UserStream), userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
streamLock: &sync.Mutex{}, streamLock: &sync.Mutex{},
lastCleanUpTime: time.Now(), lastCleanUpTime: time.Now(),
} }
@ -123,7 +123,7 @@ func (n *Notifier) OnNewEvent(
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed. // updates for a user. Must be closed.
// notify for anything before sincePos // notify for anything before sincePos
func (n *Notifier) GetListener(req syncRequest) UserStreamListener { func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener {
// Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298
// - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID
// - Incoming events wake requests for a matching room ID // - Incoming events wake requests for a matching room ID
@ -137,7 +137,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx)
} }
// Load the membership states required to notify users correctly. // Load the membership states required to notify users correctly.
@ -173,27 +173,69 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
// wakeupUsers will wake up the sync strems for all of the devices for all of the
// specified user IDs.
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
for _, userID := range userIDs { for _, userID := range userIDs {
stream := n.fetchUserStream(userID, false) for _, stream := range n.fetchUserStreams(userID) {
if stream != nil { if stream == nil {
continue
}
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
} }
} }
} }
// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, // wakeupUserDevice will wake up the sync stream for a specific user device. Other
// device streams will be left alone.
// nolint:unused
func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) {
for userID, deviceID := range userDevices {
if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}
}
}
// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true,
// a stream will be made for this device if one doesn't exist and it will be returned. This
// function does not wait for data to be available on the stream.
// NB: Callers should have locked the mutex before calling this function.
func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream {
_, ok := n.userDeviceStreams[userID]
if !ok {
if !makeIfNotExists {
return nil
}
n.userDeviceStreams[userID] = map[string]*UserDeviceStream{}
}
stream, ok := n.userDeviceStreams[userID][deviceID]
if !ok {
if !makeIfNotExists {
return nil
}
// TODO: Unbounded growth of streams (1 per user)
if stream = NewUserDeviceStream(userID, deviceID, n.currPos); stream != nil {
n.userDeviceStreams[userID][deviceID] = stream
}
}
return stream
}
// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true,
// a stream will be made for this user if one doesn't exist and it will be returned. This // a stream will be made for this user if one doesn't exist and it will be returned. This
// function does not wait for data to be available on the stream. // function does not wait for data to be available on the stream.
// NB: Callers should have locked the mutex before calling this function. // NB: Callers should have locked the mutex before calling this function.
func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream {
stream, ok := n.userStreams[userID] user, ok := n.userDeviceStreams[userID]
if !ok && makeIfNotExists { if !ok {
// TODO: Unbounded growth of streams (1 per user) return []*UserDeviceStream{}
stream = NewUserStream(userID, n.currPos)
n.userStreams[userID] = stream
} }
return stream streams := []*UserDeviceStream{}
for _, stream := range user {
streams = append(streams, stream)
}
return streams
} }
// Not thread-safe: must be called on the OnNewEvent goroutine only // Not thread-safe: must be called on the OnNewEvent goroutine only
@ -236,9 +278,14 @@ func (n *Notifier) removeEmptyUserStreams() {
n.lastCleanUpTime = now n.lastCleanUpTime = now
deleteBefore := now.Add(-5 * time.Minute) deleteBefore := now.Add(-5 * time.Minute)
for key, value := range n.userStreams { for user, byUser := range n.userDeviceStreams {
if value.TimeOfLastNonEmpty().Before(deleteBefore) { for device, stream := range byUser {
delete(n.userStreams, key) if stream.TimeOfLastNonEmpty().Before(deleteBefore) {
delete(n.userDeviceStreams[user], device)
}
if len(n.userDeviceStreams[user]) == 0 {
delete(n.userDeviceStreams, user)
}
} }
} }
} }

View File

@ -41,9 +41,11 @@ var (
) )
var ( var (
roomID = "!test:localhost" roomID = "!test:localhost"
alice = "@alice:localhost" alice = "@alice:localhost"
bob = "@bob:localhost" aliceDev = "alicedevice"
bob = "@bob:localhost"
bobDev = "bobdev"
) )
func init() { func init() {
@ -107,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
// Test that the current position is returned if a request is already behind. // Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) { func TestImmediateNotification(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld)) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
if err != nil { if err != nil {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
} }
@ -124,7 +126,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) t.Errorf("TestNewEventAndJoinedToRoom error: %w", err)
} }
@ -132,7 +134,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
wg.Done() wg.Done()
}() }()
stream := lockedFetchUserStream(n, bob) stream := lockedFetchUserStream(n, bob, bobDev)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
@ -140,6 +142,43 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestCorrectStream(t *testing.T) {
n := NewNotifier(syncPositionBefore)
stream := lockedFetchUserStream(n, bob, bobDev)
if stream.UserID != bob {
t.Fatalf("expected user %q, got %q", bob, stream.UserID)
}
if stream.DeviceID != bobDev {
t.Fatalf("expected device %q, got %q", bobDev, stream.DeviceID)
}
}
func TestCorrectStreamWakeup(t *testing.T) {
n := NewNotifier(syncPositionBefore)
awoken := make(chan string)
streamone := lockedFetchUserStream(n, alice, "one")
streamtwo := lockedFetchUserStream(n, alice, "two")
go func() {
select {
case <-streamone.signalChannel:
awoken <- "one"
case <-streamtwo.signalChannel:
awoken <- "two"
}
}()
time.Sleep(1 * time.Second)
wake := "two"
n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter)
if result := <-awoken; result != wake {
t.Fatalf("expected to wake %q, got %q", wake, result)
}
}
// Test that an invite unblocks the request // Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) { func TestNewInviteEventForUser(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier(syncPositionBefore)
@ -150,7 +189,7 @@ func TestNewInviteEventForUser(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %w", err) t.Errorf("TestNewInviteEventForUser error: %w", err)
} }
@ -158,7 +197,7 @@ func TestNewInviteEventForUser(t *testing.T) {
wg.Done() wg.Done()
}() }()
stream := lockedFetchUserStream(n, bob) stream := lockedFetchUserStream(n, bob, bobDev)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter)
@ -176,7 +215,7 @@ func TestEDUWakeup(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %w", err) t.Errorf("TestNewInviteEventForUser error: %w", err)
} }
@ -184,7 +223,7 @@ func TestEDUWakeup(t *testing.T) {
wg.Done() wg.Done()
}() }()
stream := lockedFetchUserStream(n, bob) stream := lockedFetchUserStream(n, bob, bobDev)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU)
@ -202,7 +241,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(3) wg.Add(3)
poll := func() { poll := func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestMultipleRequestWakeup error: %w", err) t.Errorf("TestMultipleRequestWakeup error: %w", err)
} }
@ -213,7 +252,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
go poll() go poll()
go poll() go poll()
stream := lockedFetchUserStream(n, bob) stream := lockedFetchUserStream(n, bob, bobDev)
waitForBlocking(stream, 3) waitForBlocking(stream, 3)
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
@ -240,24 +279,24 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// Make bob leave the room // Make bob leave the room
leaveWG.Add(1) leaveWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
} }
mustEqualPositions(t, pos, syncPositionAfter) mustEqualPositions(t, pos, syncPositionAfter)
leaveWG.Done() leaveWG.Done()
}() }()
bobStream := lockedFetchUserStream(n, bob) bobStream := lockedFetchUserStream(n, bob, bobDev)
waitForBlocking(bobStream, 1) waitForBlocking(bobStream, 1)
n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter)
leaveWG.Wait() leaveWG.Wait()
// send an event into the room. Make sure alice gets it. Bob should not. // send an event into the room. Make sure alice gets it. Bob should not.
var aliceWG sync.WaitGroup var aliceWG sync.WaitGroup
aliceStream := lockedFetchUserStream(n, alice) aliceStream := lockedFetchUserStream(n, alice, aliceDev)
aliceWG.Add(1) aliceWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionAfter))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
} }
@ -267,7 +306,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
go func() { go func() {
// this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly)
_, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) _, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
if err == nil { if err == nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil")
} }
@ -300,7 +339,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) {
} }
// Wait until something is Wait()ing on the user stream. // Wait until something is Wait()ing on the user stream.
func waitForBlocking(s *UserStream, numBlocking uint) { func waitForBlocking(s *UserDeviceStream, numBlocking uint) {
for numBlocking != s.NumWaiting() { for numBlocking != s.NumWaiting() {
// This is horrible but I don't want to add a signalling mechanism JUST for testing. // This is horrible but I don't want to add a signalling mechanism JUST for testing.
time.Sleep(1 * time.Microsecond) time.Sleep(1 * time.Microsecond)
@ -309,16 +348,19 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. // lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock.
// A new stream is made if it doesn't exist already. // A new stream is made if it doesn't exist already.
func lockedFetchUserStream(n *Notifier, userID string) *UserStream { func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
return n.fetchUserStream(userID, true) return n.fetchUserDeviceStream(userID, deviceID, true)
} }
func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest { func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{
UserID: userID,
ID: deviceID,
},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,
since: &since, since: &since,
wantFullState: false, wantFullState: false,

View File

@ -47,7 +47,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
var syncData *types.Response var syncData *types.Response
// Extract values from request // Extract values from request
userID := device.UserID
syncReq, err := newSyncRequest(req, *device) syncReq, err := newSyncRequest(req, *device)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -56,10 +55,11 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
} }
logger := util.GetLogger(req.Context()).WithFields(log.Fields{ logger := util.GetLogger(req.Context()).WithFields(log.Fields{
"userID": userID, "userID": device.UserID,
"since": syncReq.since, "deviceID": device.ID,
"timeout": syncReq.timeout, "since": syncReq.since,
"limit": syncReq.limit, "timeout": syncReq.timeout,
"limit": syncReq.limit,
}) })
currPos := rp.notifier.CurrentPosition() currPos := rp.notifier.CurrentPosition()
@ -136,7 +136,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit)
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
} }

View File

@ -23,12 +23,13 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
// UserStream represents a communication mechanism between the /sync request goroutine // UserDeviceStream represents a communication mechanism between the /sync request goroutine
// and the underlying sync server goroutines. // and the underlying sync server goroutines.
// Goroutines can get a UserStreamListener to wait for updates, and can Broadcast() // Goroutines can get a UserStreamListener to wait for updates, and can Broadcast()
// updates. // updates.
type UserStream struct { type UserDeviceStream struct {
UserID string UserID string
DeviceID string
// The lock that protects changes to this struct // The lock that protects changes to this struct
lock sync.Mutex lock sync.Mutex
// Closed when there is an update. // Closed when there is an update.
@ -41,18 +42,19 @@ type UserStream struct {
numWaiting uint numWaiting uint
} }
// UserStreamListener allows a sync request to wait for updates for a user. // UserDeviceStreamListener allows a sync request to wait for updates for a user.
type UserStreamListener struct { type UserDeviceStreamListener struct {
userStream *UserStream userStream *UserDeviceStream
// Whether the stream has been closed // Whether the stream has been closed
hasClosed bool hasClosed bool
} }
// NewUserStream creates a new user stream // NewUserDeviceStream creates a new user stream
func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { func NewUserDeviceStream(userID, deviceID string, currPos types.StreamingToken) *UserDeviceStream {
return &UserStream{ return &UserDeviceStream{
UserID: userID, UserID: userID,
DeviceID: deviceID,
timeOfLastChannel: time.Now(), timeOfLastChannel: time.Now(),
pos: currPos, pos: currPos,
signalChannel: make(chan struct{}), signalChannel: make(chan struct{}),
@ -62,18 +64,18 @@ func NewUserStream(userID string, currPos types.StreamingToken) *UserStream {
// GetListener returns UserStreamListener that a sync request can use to wait // GetListener returns UserStreamListener that a sync request can use to wait
// for new updates with. // for new updates with.
// UserStreamListener must be closed // UserStreamListener must be closed
func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { func (s *UserDeviceStream) GetListener(ctx context.Context) UserDeviceStreamListener {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
s.numWaiting++ // We decrement when UserStreamListener is closed s.numWaiting++ // We decrement when UserStreamListener is closed
listener := UserStreamListener{ listener := UserDeviceStreamListener{
userStream: s, userStream: s,
} }
// Lets be a bit paranoid here and check that Close() is being called // Lets be a bit paranoid here and check that Close() is being called
runtime.SetFinalizer(&listener, func(l *UserStreamListener) { runtime.SetFinalizer(&listener, func(l *UserDeviceStreamListener) {
if !l.hasClosed { if !l.hasClosed {
l.Close() l.Close()
} }
@ -83,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
} }
// Broadcast a new sync position for this user. // Broadcast a new sync position for this user.
func (s *UserStream) Broadcast(pos types.StreamingToken) { func (s *UserDeviceStream) Broadcast(pos types.StreamingToken) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -96,7 +98,7 @@ func (s *UserStream) Broadcast(pos types.StreamingToken) {
// NumWaiting returns the number of goroutines waiting for waiting for updates. // NumWaiting returns the number of goroutines waiting for waiting for updates.
// Used for metrics and testing. // Used for metrics and testing.
func (s *UserStream) NumWaiting() uint { func (s *UserDeviceStream) NumWaiting() uint {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
return s.numWaiting return s.numWaiting
@ -105,7 +107,7 @@ func (s *UserStream) NumWaiting() uint {
// TimeOfLastNonEmpty returns the last time that the number of waiting listeners // TimeOfLastNonEmpty returns the last time that the number of waiting listeners
// was non-empty, may be time.Now() if number of waiting listeners is currently // was non-empty, may be time.Now() if number of waiting listeners is currently
// non-empty. // non-empty.
func (s *UserStream) TimeOfLastNonEmpty() time.Time { func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -118,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
// GetSyncPosition returns last sync position which the UserStream was // GetSyncPosition returns last sync position which the UserStream was
// notified about // notified about
func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -130,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.StreamingToken {
// sincePos specifies from which point we want to be notified about. If there // sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel // has already been an update after sincePos we'll return a closed channel
// immediately. // immediately.
func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { func (s *UserDeviceStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -147,7 +149,7 @@ func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-c
} }
// Close cleans up resources used // Close cleans up resources used
func (s *UserStreamListener) Close() { func (s *UserDeviceStreamListener) Close() {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()