Use new testrig for key changes tests (#2552)

* Use new testrig for tests

* Log the error message
This commit is contained in:
Till 2022-07-05 14:50:24 +02:00 committed by GitHub
parent 43147bd654
commit f29cdb26f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,36 +1,26 @@
package storage package storage_test
import ( import (
"context" "context"
"fmt"
"io/ioutil"
"log"
"os"
"reflect" "reflect"
"testing" "testing"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
) )
var ctx = context.Background() var ctx = context.Background()
func MustCreateDatabase(t *testing.T) (Database, func()) { func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test") base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
if err != nil { if err != nil {
log.Fatal(err) t.Fatalf("failed to create new database: %v", err)
}
t.Logf("Database %s", tmpfile.Name())
db, err := NewDatabase(nil, &config.DatabaseOptions{
ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())),
})
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
return db, func() {
os.Remove(tmpfile.Name())
} }
return db, close
} }
func MustNotError(t *testing.T, err error) { func MustNotError(t *testing.T, err error) {
@ -42,151 +32,159 @@ func MustNotError(t *testing.T, err error) {
} }
func TestKeyChanges(t *testing.T) { func TestKeyChanges(t *testing.T) {
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
defer clean() db, clean := MustCreateDatabase(t, dbType)
_, err := db.StoreKeyChange(ctx, "@alice:localhost") defer clean()
MustNotError(t, err) _, err := db.StoreKeyChange(ctx, "@alice:localhost")
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") MustNotError(t, err)
MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err)
MustNotError(t, err) deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) MustNotError(t, err)
if err != nil { userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
t.Fatalf("Failed to KeyChanges: %s", err) if err != nil {
} t.Fatalf("Failed to KeyChanges: %s", err)
if latest != deviceChangeIDC { }
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) if latest != deviceChangeIDC {
} t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { }
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
} t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
} }
func TestKeyChangesNoDupes(t *testing.T) { func TestKeyChangesNoDupes(t *testing.T) {
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
defer clean() db, clean := MustCreateDatabase(t, dbType)
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") defer clean()
MustNotError(t, err) deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err)
MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
if deviceChangeIDA == deviceChangeIDB { MustNotError(t, err)
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) if deviceChangeIDA == deviceChangeIDB {
} t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") }
MustNotError(t, err) deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) MustNotError(t, err)
if err != nil { userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
t.Fatalf("Failed to KeyChanges: %s", err) if err != nil {
} t.Fatalf("Failed to KeyChanges: %s", err)
if latest != deviceChangeID { }
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) if latest != deviceChangeID {
} t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { }
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
} t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
} }
func TestKeyChangesUpperLimit(t *testing.T) { func TestKeyChangesUpperLimit(t *testing.T) {
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
defer clean() db, clean := MustCreateDatabase(t, dbType)
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") defer clean()
MustNotError(t, err) deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") MustNotError(t, err)
MustNotError(t, err) deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
_, err = db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err)
MustNotError(t, err) _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) MustNotError(t, err)
if err != nil { userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
t.Fatalf("Failed to KeyChanges: %s", err) if err != nil {
} t.Fatalf("Failed to KeyChanges: %s", err)
if latest != deviceChangeIDB { }
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) if latest != deviceChangeIDB {
} t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { }
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
} t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
} }
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys. // and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) { func TestDeviceKeysStreamIDGeneration(t *testing.T) {
var err error var err error
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
defer clean() db, clean := MustCreateDatabase(t, dbType)
alice := "@alice:TestDeviceKeysStreamIDGeneration" defer clean()
bob := "@bob:TestDeviceKeysStreamIDGeneration" alice := "@alice:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{ bob := "@bob:TestDeviceKeysStreamIDGeneration"
{ msgs := []api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate, {
DeviceKeys: &api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceID: "AAA", DeviceKeys: &api.DeviceKeys{
UserID: alice, DeviceID: "AAA",
KeyJSON: []byte(`{"key":"v1"}`), UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
}, },
// StreamID: 1 {
}, Type: api.TypeDeviceKeyUpdate,
{ DeviceKeys: &api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate, DeviceID: "AAA",
DeviceKeys: &api.DeviceKeys{ UserID: bob,
DeviceID: "AAA", KeyJSON: []byte(`{"key":"v1"}`),
UserID: bob, },
KeyJSON: []byte(`{"key":"v1"}`), // StreamID: 1 as this is a different user
}, },
// StreamID: 1 as this is a different user {
}, Type: api.TypeDeviceKeyUpdate,
{ DeviceKeys: &api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate, DeviceID: "another_device",
DeviceKeys: &api.DeviceKeys{ UserID: alice,
DeviceID: "another_device", KeyJSON: []byte(`{"key":"v1"}`),
UserID: alice, },
KeyJSON: []byte(`{"key":"v1"}`), // StreamID: 2 as this is a 2nd device key
}, },
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
} }
} MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
})
} }