package tables_test

import (
	"context"
	"testing"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/roomserver/storage/postgres"
	"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
	"github.com/matrix-org/dendrite/roomserver/storage/tables"
	"github.com/matrix-org/dendrite/roomserver/types"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/test"
	"github.com/stretchr/testify/assert"
)

// mustCreateStateSnapshotTable opens a database of the given type, creates the
// state snapshot table in it and returns the prepared table together with the
// cleanup function obtained from test.PrepareDBConnectionString.
//
// Any setup error aborts the calling test immediately (after running the
// cleanup function): continuing with a nil database or table would only end in
// a nil-pointer panic further down.
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tables.StateSnapshot, func()) {
	t.Helper()
	connStr, closeDB := test.PrepareDBConnectionString(t, dbType)

	// mustSucceed aborts the test on err. It runs closeDB first because the
	// caller never receives closeDB when setup does not complete.
	mustSucceed := func(step string, err error) {
		t.Helper()
		if err != nil {
			closeDB()
			t.Fatalf("%s: %v", step, err)
		}
	}

	db, err := sqlutil.Open(&config.DatabaseOptions{
		ConnectionString: config.DataSource(connStr),
	}, sqlutil.NewExclusiveWriter())
	mustSucceed("opening database", err)

	var tab tables.StateSnapshot
	switch dbType {
	case test.DBTypePostgres:
		mustSucceed("creating postgres state snapshot table", postgres.CreateStateSnapshotTable(db))
		tab, err = postgres.PrepareStateSnapshotTable(db)
	case test.DBTypeSQLite:
		mustSucceed("creating sqlite3 state snapshot table", sqlite3.CreateStateSnapshotTable(db))
		tab, err = sqlite3.PrepareStateSnapshotTable(db)
	default:
		// Previously an unknown type silently returned a nil table.
		closeDB()
		t.Fatalf("unsupported database type %v", dbType)
	}
	mustSucceed("preparing state snapshot table", err)
	return tab, closeDB
}

// TestStateSnapshotTable exercises InsertState and BulkSelectStateBlockNIDs
// against every database backend run by test.WithAllDatabases.
func TestStateSnapshotTable(t *testing.T) {
	ctx := context.Background()
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		tab, closeDB := mustCreateStateSnapshotTable(t, dbType)
		defer closeDB()

		// generate some dummy data
		stateBlockNIDs := make(types.StateBlockNIDs, 0, 100)
		for i := 0; i < 100; i++ {
			stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
		}

		stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
		assert.NoError(t, err)
		assert.Equal(t, types.StateSnapshotNID(1), stateNID)

		// verify ON CONFLICT; Note: this updates the sequence!
		stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
		assert.NoError(t, err)
		assert.Equal(t, types.StateSnapshotNID(1), stateNID)

		// create a second snapshot
		stateBlockNIDs2 := make(types.StateBlockNIDs, 0, 50)
		for i := 100; i < 150; i++ {
			stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
		}

		stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
		assert.NoError(t, err)
		// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
		assert.Equal(t, types.StateSnapshotNID(3), stateNID)

		nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
		assert.NoError(t, err)
		// Guard the indexing below so a short result fails the test instead of panicking.
		if assert.Len(t, nidLists, 2) {
			assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
			assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
		}

		// check we get an error if the state snapshot does not exist
		_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
		assert.Error(t, err)

		// insert a much larger list of state block NIDs (more than 65535 entries)
		// NOTE(review): presumably this guards against a 16-bit parameter/array
		// size limit in a backend; confirm the intent.
		for i := 0; i < 65555; i++ {
			stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
		}
		_, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
		assert.NoError(t, err)
	})
}