Add test infrastructure code for dendrite unit/integ tests (#2331)

* Add test infrastructure code for dendrite unit/integ tests

Start re-enabling some syncapi storage tests in the process.

* Linting

* Add postgres service to unit tests

* dendrite not syncv3

* Skip test which doesn't work

* Linting

* Add `jetstream.PrepareForTests`

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
kegsay 2022-04-08 10:12:30 +01:00 committed by GitHub
parent 955e6eb307
commit 7499147550
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 598 additions and 222 deletions

View File

@ -73,6 +73,26 @@ jobs:
timeout-minutes: 5 timeout-minutes: 5
name: Unit tests (Go ${{ matrix.go }}) name: Unit tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest runs-on: ubuntu-latest
# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres:13-alpine
# Provide the password for postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -92,6 +112,11 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test- ${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test ./... - run: go test ./...
env:
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dendrite
# build Dendrite for linux with different architectures and go versions # build Dendrite for linux with different architectures and go versions
build: build:

View File

@ -2,7 +2,6 @@ package input_test
import ( import (
"context" "context"
"fmt"
"os" "os"
"testing" "testing"
"time" "time"
@ -12,30 +11,22 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
) )
func psqlConnectionString() config.DataSource { var js nats.JetStreamContext
user := os.Getenv("POSTGRES_USER") var jc *nats.Conn
if user == "" {
user = "dendrite" func TestMain(m *testing.M) {
} var pc *process.ProcessContext
dbName := os.Getenv("POSTGRES_DB") pc, js, jc = jetstream.PrepareForTests()
if dbName == "" { code := m.Run()
dbName = "dendrite" pc.ShutdownDendrite()
} pc.WaitForComponentsToFinish()
connStr := fmt.Sprintf( os.Exit(code)
"user=%s dbname=%s sslmode=disable", user, dbName,
)
password := os.Getenv("POSTGRES_PASSWORD")
if password != "" {
connStr += fmt.Sprintf(" password=%s", password)
}
host := os.Getenv("POSTGRES_HOST")
if host != "" {
connStr += fmt.Sprintf(" host=%s", host)
}
return config.DataSource(connStr)
} }
func TestSingleTransactionOnInput(t *testing.T) { func TestSingleTransactionOnInput(t *testing.T) {
@ -63,7 +54,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
} }
db, err := storage.Open( db, err := storage.Open(
&config.DatabaseOptions{ &config.DatabaseOptions{
ConnectionString: psqlConnectionString(), ConnectionString: "",
MaxOpenConnections: 1, MaxOpenConnections: 1,
MaxIdleConnections: 1, MaxIdleConnections: 1,
}, },
@ -74,7 +65,9 @@ func TestSingleTransactionOnInput(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
inputter := &input.Inputer{ inputter := &input.Inputer{
DB: db, DB: db,
JetStream: js,
NATSClient: jc,
} }
res := &api.InputRoomEventsResponse{} res := &api.InputRoomEventsResponse{}
inputter.InputRoomEvents( inputter.InputRoomEvents(

View File

@ -13,12 +13,22 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server" natsserver "github.com/nats-io/nats-server/v2/server"
"github.com/nats-io/nats.go"
natsclient "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go"
) )
var natsServer *natsserver.Server var natsServer *natsserver.Server
var natsServerMutex sync.Mutex var natsServerMutex sync.Mutex
func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) {
cfg := &config.Dendrite{}
cfg.Defaults(true)
cfg.Global.JetStream.InMemory = true
pc := process.NewProcessContext()
js, jc := Prepare(pc, &cfg.Global.JetStream)
return pc, js, jc
}
func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) { func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
// check if we need an in-process NATS Server // check if we need an in-process NATS Server
if len(cfg.Addresses) != 0 { if len(cfg.Addresses) != 0 {

View File

@ -1,121 +1,28 @@
package storage_test package storage_test
// TODO: Fix these tests
/*
import ( import (
"context" "context"
"crypto/ed25519"
"encoding/json"
"fmt" "fmt"
"os"
"testing" "testing"
"time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
var ( var ctx = context.Background()
ctx = context.Background()
emptyStateKey = ""
testOrigin = gomatrixserverlib.ServerName("hollow.knight")
testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin)
testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin)
testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin)
testUserDeviceA = userapi.Device{
UserID: testUserIDA,
ID: "device_id_A",
DisplayName: "Device A",
}
testRoomVersion = gomatrixserverlib.RoomVersionV4
testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent { func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
b.RoomID = roomID connStr, close := test.PrepareDBConnectionString(t, dbType)
if prevs != nil { db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{
prevIDs := make([]string, len(prevs)) ConnectionString: config.DataSource(connStr),
for i := range prevs {
prevIDs[i] = prevs[i].EventID()
}
b.PrevEvents = prevIDs
}
e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion)
if err != nil {
t.Fatalf("failed to build event: %s", err)
}
return e.Headered(testRoomVersion)
}
func MustCreateDatabase(t *testing.T) storage.Database {
dbname := fmt.Sprintf("test_%s.db", t.Name())
if _, err := os.Stat(dbname); err == nil {
if err = os.Remove(dbname); err != nil {
t.Fatalf("tried to delete stale test database but failed: %s", err)
}
}
db, err := sqlite3.NewDatabase(&config.DatabaseOptions{
ConnectionString: config.DataSource(fmt.Sprintf("file:%s", dbname)),
}) })
if err != nil { if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err) t.Fatalf("NewSyncServerDatasource returned %s", err)
} }
return db return db, close
}
// Create a list of events which include a create event, join event and some messages.
func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) {
var events []*gomatrixserverlib.HeaderedEvent
events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)),
Type: "m.room.create",
StateKey: &emptyStateKey,
Sender: userA,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(`{"membership":"join"}`),
Type: "m.room.member",
StateKey: &userA,
Sender: userA,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
for i := 0; i < 10; i++ {
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
Type: "m.room.message",
Sender: userA,
Depth: int64(len(events) + 1),
}))
}
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(`{"membership":"join"}`),
Type: "m.room.member",
StateKey: &userB,
Sender: userB,
Depth: int64(len(events) + 1),
}))
state = append(state, events[len(events)-1])
for i := 0; i < 10; i++ {
events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)),
Type: "m.room.message",
Sender: userB,
Depth: int64(len(events) + 1),
}))
}
return events, state
} }
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@ -138,111 +45,115 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
} }
func TestWriteEvents(t *testing.T) { func TestWriteEvents(t *testing.T) {
t.Parallel() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db := MustCreateDatabase(t) t.Parallel()
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) alice := test.NewUser()
MustWriteEvents(t, db, events) r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
defer close()
MustWriteEvents(t, db, r.Events())
})
} }
// These tests assert basic functionality of the IncrementalSync and CompleteSync functions. // These tests assert basic functionality of RecentEvents for PDUs
func TestSyncResponse(t *testing.T) { func TestRecentEventsPDU(t *testing.T) {
t.Parallel() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db := MustCreateDatabase(t) db, close := MustCreateDatabase(t, dbType)
events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) defer close()
positions := MustWriteEvents(t, db, events) alice := test.NewUser()
latest, err := db.SyncPosition(ctx) var filter gomatrixserverlib.RoomEventFilter
if err != nil { filter.Limit = 100
t.Fatalf("failed to get SyncPosition: %s", err) r := test.NewRoom(t, alice)
} r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
events := r.Events()
positions := MustWriteEvents(t, db, events)
latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
}
testCases := []struct { testCases := []struct {
Name string Name string
DoSync func() (*types.Response, error) From types.StreamPosition
WantTimeline []*gomatrixserverlib.HeaderedEvent To types.StreamPosition
WantState []*gomatrixserverlib.HeaderedEvent WantEvents []*gomatrixserverlib.HeaderedEvent
}{ WantLimited bool
// The purpose of this test is to make sure that incremental syncs are including up to the latest events. }{
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. // The purpose of this test is to make sure that incremental syncs are including up to the latest events.
// It makes sure the response includes the final event. // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
{ // It makes sure the response includes the final event.
Name: "IncrementalSync penultimate", {
DoSync: func() (*types.Response, error) { Name: "IncrementalSync penultimate",
from := types.StreamingToken{ // pretend we are at the penultimate event From: positions[len(positions)-2], // pretend we are at the penultimate event
PDUPosition: positions[len(positions)-2], To: latest,
} WantEvents: events[len(events)-1:],
res := types.NewResponse() WantLimited: false,
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
}, },
WantTimeline: events[len(events)-1:], /*
}, // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the // number of returned events. This is critical for big rooms hence the test here.
// number of returned events. This is critical for big rooms hence the test here. {
{ Name: "IncrementalSync limited",
Name: "IncrementalSync limited", DoSync: func() (*types.Response, error) {
DoSync: func() (*types.Response, error) { from := types.StreamingToken{ // pretend we are 10 events behind
from := types.StreamingToken{ // pretend we are 10 events behind PDUPosition: positions[len(positions)-11],
PDUPosition: positions[len(positions)-11], }
} res := types.NewResponse()
res := types.NewResponse() // limit is set to 5
// limit is set to 5 return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) },
}, // want the last 5 events, NOT the last 10.
// want the last 5 events, NOT the last 10. WantTimeline: events[len(events)-5:],
WantTimeline: events[len(events)-5:], },
}, // The purpose of this test is to check that CompleteSync returns all the current state as well as
// The purpose of this test is to check that CompleteSync returns all the current state as well as // honouring the `numRecentEventsPerRoom` value
// honouring the `numRecentEventsPerRoom` value {
{ Name: "CompleteSync limited",
Name: "CompleteSync limited", DoSync: func() (*types.Response, error) {
DoSync: func() (*types.Response, error) { res := types.NewResponse()
res := types.NewResponse() // limit set to 5
// limit set to 5 return db.CompleteSync(ctx, res, testUserDeviceA, 5)
return db.CompleteSync(ctx, res, testUserDeviceA, 5) },
}, // want the last 5 events
// want the last 5 events WantTimeline: events[len(events)-5:],
WantTimeline: events[len(events)-5:], // want all state for the room
// want all state for the room WantState: state,
WantState: state, },
}, // The purpose of this test is to check that CompleteSync can return everything with a high enough
// The purpose of this test is to check that CompleteSync can return everything with a high enough // `numRecentEventsPerRoom`.
// `numRecentEventsPerRoom`. {
{ Name: "CompleteSync",
Name: "CompleteSync", DoSync: func() (*types.Response, error) {
DoSync: func() (*types.Response, error) { res := types.NewResponse()
res := types.NewResponse() return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) },
}, WantTimeline: events,
WantTimeline: events, // We want no state at all as that field in /sync is the delta between the token (beginning of time)
// We want no state at all as that field in /sync is the delta between the token (beginning of time) // and the START of the timeline.
// and the START of the timeline. }, */
}, }
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(st *testing.T) { t.Run(tc.Name, func(st *testing.T) {
res, err := tc.DoSync() gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
if err != nil { From: tc.From,
st.Fatalf("failed to do sync: %s", err) To: tc.To,
} }, &filter, true, true)
next := types.StreamingToken{ if err != nil {
PDUPosition: latest.PDUPosition, st.Fatalf("failed to do sync: %s", err)
TypingPosition: latest.TypingPosition, }
ReceiptPosition: latest.ReceiptPosition, if limited != tc.WantLimited {
SendToDevicePosition: latest.SendToDevicePosition, st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
} }
if res.NextBatch.String() != next.String() { if len(gotEvents) != len(tc.WantEvents) {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
} }
roomRes, ok := res.Rooms.Join[testRoomID] })
if !ok { }
st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) })
}
assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState)
assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline)
})
}
} }
/*
func TestGetEventsInRangeWithPrevBatch(t *testing.T) { func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel() t.Parallel()
db := MustCreateDatabase(t) db := MustCreateDatabase(t)

127
test/db.go Normal file
View File

@ -0,0 +1,127 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"database/sql"
"fmt"
"os"
"os/exec"
"os/user"
"testing"
)
type DBType int
var DBTypeSQLite DBType = 1
var DBTypePostgres DBType = 2
var Quiet = false
func createLocalDB(dbName string) string {
if !Quiet {
fmt.Println("Note: tests require a postgres install accessible to the current user")
}
createDB := exec.Command("createdb", dbName)
if !Quiet {
createDB.Stdout = os.Stdout
createDB.Stderr = os.Stderr
}
err := createDB.Run()
if err != nil && !Quiet {
fmt.Println("createLocalDB returned error:", err)
}
return dbName
}
func currentUser() string {
user, err := user.Current()
if err != nil {
if !Quiet {
fmt.Println("cannot get current user: ", err)
}
os.Exit(2)
}
return user.Username
}
// Prepare a sqlite or postgres connection string for testing.
// Returns the connection string to use and a close function which must be called when the test finishes.
// Calling this function twice will return the same database, which will have data from previous tests
// unless close() is called.
// TODO: namespace for concurrent package tests
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
if dbType == DBTypeSQLite {
dbname := "dendrite_test.db"
return fmt.Sprintf("file:%s", dbname), func() {
err := os.Remove(dbname)
if err != nil {
t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err)
}
}
}
// Required vars: user and db
// We'll try to infer from the local env if they are missing
user := os.Getenv("POSTGRES_USER")
if user == "" {
user = currentUser()
}
dbName := os.Getenv("POSTGRES_DB")
if dbName == "" {
dbName = createLocalDB("dendrite_test")
}
connStr = fmt.Sprintf(
"user=%s dbname=%s sslmode=disable",
user, dbName,
)
// optional vars, used in CI
password := os.Getenv("POSTGRES_PASSWORD")
if password != "" {
connStr += fmt.Sprintf(" password=%s", password)
}
host := os.Getenv("POSTGRES_HOST")
if host != "" {
connStr += fmt.Sprintf(" host=%s", host)
}
return connStr, func() {
// Drop all tables on the database to get a fresh instance
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err)
}
_, err = db.Exec(`DROP SCHEMA public CASCADE;
CREATE SCHEMA public;`)
if err != nil {
t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err)
}
_ = db.Close()
}
}
// Creates subtests with each known DBType
func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
dbs := map[string]DBType{
"postgres": DBTypePostgres,
"sqlite": DBTypeSQLite,
}
for dbName, dbType := range dbs {
dbt := dbType
t.Run(dbName, func(tt *testing.T) {
testFn(tt, dbt)
})
}
}

51
test/event.go Normal file
View File

@ -0,0 +1,51 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/ed25519"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
type eventMods struct {
originServerTS time.Time
origin gomatrixserverlib.ServerName
stateKey *string
unsigned interface{}
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
}
type eventModifier func(e *eventMods)
func WithTimestamp(ts time.Time) eventModifier {
return func(e *eventMods) {
e.originServerTS = ts
}
}
func WithStateKey(skey string) eventModifier {
return func(e *eventMods) {
e.stateKey = &skey
}
}
func WithUnsigned(unsigned interface{}) eventModifier {
return func(e *eventMods) {
e.unsigned = unsigned
}
}

223
test/room.go Normal file
View File

@ -0,0 +1,223 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/ed25519"
"encoding/json"
"fmt"
"sync/atomic"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrixserverlib"
)
type Preset int
var (
PresetNone Preset = 0
PresetPrivateChat Preset = 1
PresetPublicChat Preset = 2
PresetTrustedPrivateChat Preset = 3
roomIDCounter = int64(0)
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
type Room struct {
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
creator *User
authEvents gomatrixserverlib.AuthEvents
events []*gomatrixserverlib.HeaderedEvent
}
// Create a new test room. Automatically creates the initial create events.
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
t.Helper()
counter := atomic.AddInt64(&roomIDCounter, 1)
// set defaults then let roomModifiers override
r := &Room{
ID: fmt.Sprintf("!%d:localhost", counter),
creator: creator,
authEvents: gomatrixserverlib.NewAuthEvents(nil),
preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9,
}
for _, m := range modifiers {
m(t, r)
}
r.insertCreateEvents(t)
return r
}
func (r *Room) insertCreateEvents(t *testing.T) {
t.Helper()
var joinRule gomatrixserverlib.JoinRuleContent
var hisVis gomatrixserverlib.HistoryVisibilityContent
plContent := eventutil.InitialPowerLevelsContent(r.creator.ID)
switch r.preset {
case PresetTrustedPrivateChat:
fallthrough
case PresetPrivateChat:
joinRule.JoinRule = "invite"
hisVis.HistoryVisibility = "shared"
case PresetPublicChat:
joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared"
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
"creator": r.creator.ID,
"room_version": r.Version,
}, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, WithStateKey(r.creator.ID))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
}
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper()
depth := 1 + len(r.events) // depth starts at 1
// possible event modifiers (optional fields)
mod := &eventMods{}
for _, m := range mods {
m(mod)
}
if mod.privKey == nil {
mod.privKey = testPrivateKey
}
if mod.keyID == "" {
mod.keyID = testKeyID
}
if mod.originServerTS.IsZero() {
mod.originServerTS = time.Now()
}
if mod.origin == "" {
mod.origin = gomatrixserverlib.ServerName("localhost")
}
var unsigned gomatrixserverlib.RawJSON
var err error
if mod.unsigned != nil {
unsigned, err = json.Marshal(mod.unsigned)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to marshal unsigned field: %s", eventType, err)
}
}
builder := &gomatrixserverlib.EventBuilder{
Sender: creator.ID,
RoomID: r.ID,
Type: eventType,
StateKey: mod.stateKey,
Depth: int64(depth),
Unsigned: unsigned,
}
err = builder.SetContent(content)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err)
}
if depth > 1 {
builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()}
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err)
}
refs, err := eventsNeeded.AuthEventReferences(&r.authEvents)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err)
}
builder.AuthEvents = refs
ev, err := builder.Build(
mod.originServerTS, mod.origin, mod.keyID,
mod.privKey, r.Version,
)
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
}
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
}
return ev.Headered(r.Version)
}
// Add a new event to this room DAG. Not thread-safe.
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
t.Helper()
// Add the event to the list of auth events
r.events = append(r.events, he)
if he.StateKey() != nil {
err := r.authEvents.AddEvent(he.Unwrap())
if err != nil {
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
}
}
}
func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
return r.events
}
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper()
he := r.CreateEvent(t, creator, eventType, content, mods...)
r.InsertEvent(t, he)
return he
}
// All room modifiers are below
type roomModifier func(t *testing.T, r *Room)
func RoomPreset(p Preset) roomModifier {
return func(t *testing.T, r *Room) {
switch p {
case PresetPrivateChat:
fallthrough
case PresetPublicChat:
fallthrough
case PresetTrustedPrivateChat:
fallthrough
case PresetNone:
r.preset = p
default:
t.Errorf("invalid RoomPreset: %v", p)
}
}
}
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
return func(t *testing.T, r *Room) {
r.Version = ver
}
}

36
test/user.go Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"fmt"
"sync/atomic"
)
var (
userIDCounter = int64(0)
)
type User struct {
ID string
}
func NewUser() *User {
counter := atomic.AddInt64(&userIDCounter, 1)
u := &User{
ID: fmt.Sprintf("@%d:localhost", counter),
}
return u
}