From e485f9c2bd15bca397229444399fa7e168eca43d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Mar 2022 13:17:28 +0000 Subject: [PATCH] 64-bit stream IDs for device list updates (#2267) --- federationapi/consumers/keychange.go | 4 ++-- keyserver/api/api.go | 6 +++--- keyserver/internal/device_list_update.go | 2 +- keyserver/internal/device_list_update_test.go | 14 +++++++------- keyserver/internal/internal.go | 2 +- keyserver/storage/interface.go | 2 +- keyserver/storage/postgres/device_keys_table.go | 10 +++++----- keyserver/storage/shared/storage.go | 12 ++++-------- keyserver/storage/sqlite3/device_keys_table.go | 14 +++++++------- keyserver/storage/storage_test.go | 2 +- keyserver/storage/tables/interface.go | 2 +- 11 files changed, 33 insertions(+), 37 deletions(-) diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 22dbc32d..33d716d2 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -203,9 +203,9 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return err == nil } -func prevID(streamID int) []int { +func prevID(streamID int64) []int64 { if streamID <= 1 { return nil } - return []int{streamID - 1} + return []int64{streamID - 1} } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 54eb04f8..d361c622 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -70,7 +70,7 @@ type DeviceMessage struct { *DeviceKeys `json:"DeviceKeys,omitempty"` *eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` // A monotonically increasing number which represents device changes for this user. - StreamID int + StreamID int64 DeviceChangeID int64 } @@ -108,7 +108,7 @@ type DeviceKeys struct { } // WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { return DeviceMessage{ DeviceKeys: k, StreamID: streamID, @@ -281,7 +281,7 @@ type QueryDeviceMessagesRequest struct { type QueryDeviceMessagesResponse struct { // The latest stream ID - StreamID int + StreamID int64 Devices []DeviceMessage Error *KeyError } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 974d0196..4b2b8c18 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -109,7 +109,7 @@ type DeviceListUpdaterDatabase interface { StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index ff939355..0033a508 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -46,7 +46,7 @@ func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) erro type mockDeviceListUpdaterDatabase struct { staleUsers map[string]bool - prevIDsExist func(string, []int) bool + prevIDsExist func(string, []int64) bool storedKeys []api.DeviceMessage mu sync.Mutex // protect staleUsers } @@ -101,7 +101,7 @@ func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Contex } // PrevIDsExists returns true if all prev IDs exist for this user. -func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { +func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { return d.prevIDsExist(userID, prevIDs), nil } @@ -139,7 +139,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix func TestUpdateHavePrevID(t *testing.T) { db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int) bool { + prevIDsExist: func(string, []int64) bool { return true }, } @@ -151,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) { Deleted: false, DeviceID: "FOO", Keys: []byte(`{"key":"value"}`), - PrevID: []int{0}, + PrevID: []int64{0}, StreamID: 1, UserID: "@alice:localhost", } @@ -185,7 +185,7 @@ func TestUpdateHavePrevID(t *testing.T) { func TestUpdateNoPrevID(t *testing.T) { db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int) bool { + prevIDsExist: func(string, []int64) bool { return false }, } @@ -226,7 +226,7 @@ func TestUpdateNoPrevID(t *testing.T) { Deleted: false, DeviceID: "another_device_id", Keys: []byte(`{"key":"value"}`), - PrevID: []int{3}, + PrevID: []int64{3}, StreamID: 4, UserID: remoteUserID, } @@ -268,7 +268,7 @@ func TestDebounce(t *testing.T) { t.Skipf("panic on closed channel on GHA") db := &mockDeviceListUpdaterDatabase{ staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int) bool { + prevIDsExist: func(string, []int64) bool { return true }, } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 0a8bef95..cc9d3a61 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -205,7 +205,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query } return } - maxStreamID := 0 + maxStreamID := int64(0) for _, m := range msgs { if m.StreamID > maxStreamID { maxStreamID = m.StreamID diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 4dffe695..16e03477 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -49,7 +49,7 @@ type Database interface { StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 628301cf..ccd20cbd 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - var streamID int + var streamID int64 var displayName sql.NullString err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { @@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] return nil } -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { // nullable if there are no results - var nullStream sql.NullInt32 + var nullStream sql.NullInt64 err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } if nullStream.Valid { - streamID = nullStream.Int32 + streamID = nullStream.Int64 } return } @@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID } dk.UserID = userID var keyJSON string - var streamID int + var streamID int64 var displayName sql.NullString if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index f2790c8d..03215b93 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { - sids := make([]int64, len(prevIDs)) - for i := range prevIDs { - sids[i] = int64(prevIDs[i]) - } - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids) +func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) if err != nil { return false, err } @@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int) + userIDToStreamID := make(map[string]int64) for _, k := range keys { userIDToStreamID[k.UserID] = 0 } @@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe if err != nil { return err } - userIDToStreamID[userID] = int(streamID) + userIDToStreamID[userID] = streamID } // set the stream IDs for each key for i := range keys { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index b461424c..e77b49b3 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.Type = api.TypeDeviceKeyUpdate dk.UserID = userID var keyJSON string - var streamID int + var streamID int64 var displayName sql.NullString if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err @@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - var streamID int + var streamID int64 var displayName sql.NullString err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { @@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] return nil } -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { // nullable if there are no results - var nullStream sql.NullInt32 + var nullStream sql.NullInt64 err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } if nullStream.Valid { - streamID = nullStream.Int32 + streamID = nullStream.Int64 } return } @@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID } query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) // nullable if there are no results - var count sql.NullInt32 + var count sql.NullInt64 err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) if err != nil { return 0, err } if count.Valid { - return int(count.Int32), nil + return int(count.Int64), nil } return 0, nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 4d513724..84d2098a 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } - wantStreamIDs := map[string]int{ + wantStreamIDs := map[string]int64{ "AAA": 3, "another_device": 2, } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index cd171959..f840cd1f 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -37,7 +37,7 @@ type OneTimeKeys interface { type DeviceKeys interface { SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error