package storage_test import ( "context" "reflect" "sync" "testing" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" ) var ctx = context.Background() func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { base, close := testrig.CreateBaseDendrite(t, dbType) db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create new database: %v", err) } return db, close } func MustNotError(t *testing.T, err error) { t.Helper() if err == nil { return } t.Fatalf("operation failed: %s", err) } func TestKeyChanges(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := MustCreateDatabase(t, dbType) defer clean() _, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") MustNotError(t, err) deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } if latest != deviceChangeIDC { t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) } if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } }) } func TestKeyChangesNoDupes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := MustCreateDatabase(t, dbType) defer clean() deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) if deviceChangeIDA == deviceChangeIDB { t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) } deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } if latest != deviceChangeID { t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) } if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } }) } func TestKeyChangesUpperLimit(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := MustCreateDatabase(t, dbType) defer clean() deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") MustNotError(t, err) _, err = db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } if latest != deviceChangeIDB { t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) } if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } }) } var dbLock sync.Mutex var deviceArray = []string{"AAA", "another_device"} // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // and that they are returned correctly when querying for device keys. func TestDeviceKeysStreamIDGeneration(t *testing.T) { var err error test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := MustCreateDatabase(t, dbType) defer clean() alice := "@alice:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration" msgs := []api.DeviceMessage{ { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: alice, KeyJSON: []byte(`{"key":"v1"}`), }, // StreamID: 1 }, { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: bob, KeyJSON: []byte(`{"key":"v1"}`), }, // StreamID: 1 as this is a different user }, { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "another_device", UserID: alice, KeyJSON: []byte(`{"key":"v1"}`), }, // StreamID: 2 as this is a 2nd device key }, } MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 1 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) } if msgs[1].StreamID != 1 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) } if msgs[2].StreamID != 2 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) } // updating a device sets the next stream ID for that user msgs = []api.DeviceMessage{ { Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: alice, KeyJSON: []byte(`{"key":"v2"}`), }, // StreamID: 3 }, } MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 3 { t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) } dbLock.Lock() defer dbLock.Unlock() // Querying for device keys returns the latest stream IDs msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } wantStreamIDs := map[string]int64{ "AAA": 3, "another_device": 2, } if len(msgs) != len(wantStreamIDs) { t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) } for _, m := range msgs { if m.StreamID != wantStreamIDs[m.DeviceID] { t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) } } }) }