Use new testrig for key changes tests (#2552)

* Use new testrig for tests

* Log the error message
This commit is contained in:
Till 2022-07-05 14:50:24 +02:00 committed by GitHub
parent 43147bd654
commit f29cdb26f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,36 +1,26 @@
package storage package storage_test
import ( import (
"context" "context"
"fmt"
"io/ioutil"
"log"
"os"
"reflect" "reflect"
"testing" "testing"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
) )
var ctx = context.Background() var ctx = context.Background()
func MustCreateDatabase(t *testing.T) (Database, func()) { func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test") base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
if err != nil { if err != nil {
log.Fatal(err) t.Fatalf("failed to create new database: %v", err)
}
t.Logf("Database %s", tmpfile.Name())
db, err := NewDatabase(nil, &config.DatabaseOptions{
ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())),
})
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
return db, func() {
os.Remove(tmpfile.Name())
} }
return db, close
} }
func MustNotError(t *testing.T, err error) { func MustNotError(t *testing.T, err error) {
@ -42,7 +32,8 @@ func MustNotError(t *testing.T, err error) {
} }
func TestKeyChanges(t *testing.T) { func TestKeyChanges(t *testing.T) {
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean() defer clean()
_, err := db.StoreKeyChange(ctx, "@alice:localhost") _, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err) MustNotError(t, err)
@ -60,10 +51,12 @@ func TestKeyChanges(t *testing.T) {
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
} }
})
} }
func TestKeyChangesNoDupes(t *testing.T) { func TestKeyChangesNoDupes(t *testing.T) {
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean() defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err) MustNotError(t, err)
@ -84,10 +77,12 @@ func TestKeyChangesNoDupes(t *testing.T) {
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
} }
})
} }
func TestKeyChangesUpperLimit(t *testing.T) { func TestKeyChangesUpperLimit(t *testing.T) {
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean() defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err) MustNotError(t, err)
@ -105,13 +100,15 @@ func TestKeyChangesUpperLimit(t *testing.T) {
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
} }
})
} }
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys. // and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) { func TestDeviceKeysStreamIDGeneration(t *testing.T) {
var err error var err error
db, clean := MustCreateDatabase(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean() defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration" alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration"
@ -189,4 +186,5 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
} }
} }
})
} }