diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 164be6be..e8d1bfe8 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -48,13 +48,19 @@ type mockDeviceListUpdaterDatabase struct { staleUsers map[string]bool prevIDsExist func(string, []int) bool storedKeys []api.DeviceMessage + mu sync.Mutex // protect staleUsers } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + d.mu.Lock() + defer d.mu.Unlock() var result []string - for userID := range d.staleUsers { + for userID, isStale := range d.staleUsers { + if !isStale { + continue + } _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err @@ -75,6 +81,8 @@ func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, do // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.mu.Lock() + defer d.mu.Unlock() d.staleUsers[userID] = isStale return nil } @@ -247,3 +255,82 @@ func TestUpdateNoPrevID(t *testing.T) { } } + +// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the +// update is still ongoing. +func TestDebounce(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + fedCh := make(chan *http.Response, 1) + srv := gomatrixserverlib.ServerName("example.com") + userID := "@alice:example.com" + keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + incomingFedReq := make(chan struct{}) + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + close(incomingFedReq) + return <-fedCh, nil + }) + updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + + // hit this 5 times + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { + t.Errorf("ManualUpdate: %s", err) + } + }() + } + + // wait until the updater hits federation + select { + case <-incomingFedReq: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for updater to hit federation") + } + + // user should be marked as stale + if !db.staleUsers[userID] { + t.Errorf("user %s not marked as stale", userID) + } + // now send the response over federation + fedCh <- &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(` + { + "user_id": "` + userID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + } + close(fedCh) + // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response + // and should panic with read on a closed channel + wg.Wait() + + // user is no longer stale now + if db.staleUsers[userID] { + t.Errorf("user %s is marked as stale", userID) + } +}