mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-08 18:16:59 +00:00
syncapi: add more tests; fix more bugs (#2338)
* syncapi: add more tests; fix more bugs bugfixes: - The postgres impl of TopologyTable.SelectEventIDsInRange did not use the provided txn - The postgres impl of EventsTable.SelectEvents did not preserve the ordering of the input event IDs in the output events slice - The sqlite impl of EventsTable.SelectEvents did not use a bulk `IN ($1)` query. Added tests: - `TestGetEventsInRangeWithTopologyToken` - `TestOutputRoomEventsTable` - `TestTopologyTable` * -p 1 for now
This commit is contained in:
parent
986d27a128
commit
6d25bd6ca5
2
.github/workflows/dendrite.yml
vendored
2
.github/workflows/dendrite.yml
vendored
@ -111,7 +111,7 @@ jobs:
|
|||||||
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go${{ matrix.go }}-test-
|
${{ runner.os }}-go${{ matrix.go }}-test-
|
||||||
- run: go test ./...
|
- run: go test -p 1 ./...
|
||||||
env:
|
env:
|
||||||
POSTGRES_HOST: localhost
|
POSTGRES_HOST: localhost
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
|
@ -104,7 +104,7 @@ type Database interface {
|
|||||||
// DeletePeek deletes all peeks for a given room by a given user
|
// DeletePeek deletes all peeks for a given room by a given user
|
||||||
// Returns an error if there was a problem communicating with the database.
|
// Returns an error if there was a problem communicating with the database.
|
||||||
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
||||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
|
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
||||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
|
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||||
// EventPositionInTopology returns the depth and stream position of the given event.
|
// EventPositionInTopology returns the depth and stream position of the given event.
|
||||||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||||
|
@ -427,7 +427,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||||||
// selectEvents returns the events for the given event IDs. If an event is
|
// selectEvents returns the events for the given event IDs. If an event is
|
||||||
// missing from the database, it will be omitted.
|
// missing from the database, it will be omitted.
|
||||||
func (s *outputRoomEventsStatements) SelectEvents(
|
func (s *outputRoomEventsStatements) SelectEvents(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
|
||||||
) ([]types.StreamEvent, error) {
|
) ([]types.StreamEvent, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
@ -435,7 +435,25 @@ func (s *outputRoomEventsStatements) SelectEvents(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||||
return rowsToStreamEvents(rows)
|
streamEvents, err := rowsToStreamEvents(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if preserveOrder {
|
||||||
|
eventMap := make(map[string]types.StreamEvent)
|
||||||
|
for _, ev := range streamEvents {
|
||||||
|
eventMap[ev.EventID()] = ev
|
||||||
|
}
|
||||||
|
var returnEvents []types.StreamEvent
|
||||||
|
for _, eventID := range eventIDs {
|
||||||
|
ev, ok := eventMap[eventID]
|
||||||
|
if ok {
|
||||||
|
returnEvents = append(returnEvents, ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return returnEvents, nil
|
||||||
|
}
|
||||||
|
return streamEvents, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||||
|
@ -148,9 +148,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
|||||||
// is requested or not.
|
// is requested or not.
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
if chronologicalOrder {
|
if chronologicalOrder {
|
||||||
stmt = s.selectEventIDsInRangeASCStmt
|
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
|
||||||
} else {
|
} else {
|
||||||
stmt = s.selectEventIDsInRangeDESCStmt
|
stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the event IDs.
|
// Query the event IDs.
|
||||||
|
@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
|
|||||||
// Returns an error if there was a problem talking with the database.
|
// Returns an error if there was a problem talking with the database.
|
||||||
// Does not include any transaction IDs in the returned events.
|
// Does not include any transaction IDs in the returned events.
|
||||||
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
|
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
|
|||||||
|
|
||||||
// Check if we have all of the event's previous events. If an event is
|
// Check if we have all of the event's previous events. If an event is
|
||||||
// missing, add it to the room's backward extremities.
|
// missing, add it to the room's backward extremities.
|
||||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
|
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -457,7 +457,7 @@ func (d *Database) GetEventsInTopologicalRange(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve the events' contents using their IDs.
|
// Retrieve the events' contents using their IDs.
|
||||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
|
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -619,7 +619,7 @@ func (d *Database) fetchMissingStateEvents(
|
|||||||
) ([]types.StreamEvent, error) {
|
) ([]types.StreamEvent, error) {
|
||||||
// Fetch from the events table first so we pick up the stream ID for the
|
// Fetch from the events table first so we pick up the stream ID for the
|
||||||
// event.
|
// event.
|
||||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
|
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -51,13 +51,13 @@ const selectMaxAccountDataIDSQL = "" +
|
|||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectMaxAccountDataIDStmt *sql.Stmt
|
selectMaxAccountDataIDStmt *sql.Stmt
|
||||||
selectAccountDataInRangeStmt *sql.Stmt
|
selectAccountDataInRangeStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
|
||||||
s := &accountDataStatements{
|
s := &accountDataStatements{
|
||||||
db: db,
|
db: db,
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
|
@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
|
|||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
deleteRoomStateForRoomStmt *sql.Stmt
|
deleteRoomStateForRoomStmt *sql.Stmt
|
||||||
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
|
|||||||
selectStateEventStmt *sql.Stmt
|
selectStateEventStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
s := ¤tRoomStateStatements{
|
s := ¤tRoomStateStatements{
|
||||||
db: db,
|
db: db,
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
|
@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
|
|||||||
|
|
||||||
type inviteEventsStatements struct {
|
type inviteEventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
insertInviteEventStmt *sql.Stmt
|
insertInviteEventStmt *sql.Stmt
|
||||||
selectInviteEventsInRangeStmt *sql.Stmt
|
selectInviteEventsInRangeStmt *sql.Stmt
|
||||||
deleteInviteEventStmt *sql.Stmt
|
deleteInviteEventStmt *sql.Stmt
|
||||||
selectMaxInviteIDStmt *sql.Stmt
|
selectMaxInviteIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
|
||||||
s := &inviteEventsStatements{
|
s := &inviteEventsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
|
@ -58,7 +58,7 @@ const insertEventSQL = "" +
|
|||||||
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
|
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
|
||||||
|
|
||||||
const selectEventsSQL = "" +
|
const selectEventsSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
|
||||||
|
|
||||||
const selectRecentEventsSQL = "" +
|
const selectRecentEventsSQL = "" +
|
||||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||||
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
|
|||||||
|
|
||||||
type outputRoomEventsStatements struct {
|
type outputRoomEventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventsStmt *sql.Stmt
|
|
||||||
selectMaxEventIDStmt *sql.Stmt
|
selectMaxEventIDStmt *sql.Stmt
|
||||||
updateEventJSONStmt *sql.Stmt
|
updateEventJSONStmt *sql.Stmt
|
||||||
deleteEventsForRoomStmt *sql.Stmt
|
deleteEventsForRoomStmt *sql.Stmt
|
||||||
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
|
|||||||
selectContextAfterEventStmt *sql.Stmt
|
selectContextAfterEventStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
|
||||||
s := &outputRoomEventsStatements{
|
s := &outputRoomEventsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
|
|||||||
}
|
}
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventsStmt, selectEventsSQL},
|
|
||||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||||
{&s.updateEventJSONStmt, updateEventJSONSQL},
|
{&s.updateEventJSONStmt, updateEventJSONSQL},
|
||||||
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
|
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
|
||||||
@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||||||
// selectEvents returns the events for the given event IDs. If an event is
|
// selectEvents returns the events for the given event IDs. If an event is
|
||||||
// missing from the database, it will be omitted.
|
// missing from the database, it will be omitted.
|
||||||
func (s *outputRoomEventsStatements) SelectEvents(
|
func (s *outputRoomEventsStatements) SelectEvents(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
|
||||||
) ([]types.StreamEvent, error) {
|
) ([]types.StreamEvent, error) {
|
||||||
var returnEvents []types.StreamEvent
|
iEventIDs := make([]interface{}, len(eventIDs))
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
for i := range eventIDs {
|
||||||
for _, eventID := range eventIDs {
|
iEventIDs[i] = eventIDs[i]
|
||||||
rows, err := stmt.QueryContext(ctx, eventID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
|
|
||||||
returnEvents = append(returnEvents, streamEvents...)
|
|
||||||
}
|
|
||||||
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
|
||||||
}
|
}
|
||||||
return returnEvents, nil
|
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||||
|
streamEvents, err := rowsToStreamEvents(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if preserveOrder {
|
||||||
|
var returnEvents []types.StreamEvent
|
||||||
|
eventMap := make(map[string]types.StreamEvent)
|
||||||
|
for _, ev := range streamEvents {
|
||||||
|
eventMap[ev.EventID()] = ev
|
||||||
|
}
|
||||||
|
for _, eventID := range eventIDs {
|
||||||
|
ev, ok := eventMap[eventID]
|
||||||
|
if ok {
|
||||||
|
returnEvents = append(returnEvents, ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return returnEvents, nil
|
||||||
|
}
|
||||||
|
return streamEvents, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||||
|
@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
|
|||||||
|
|
||||||
type peekStatements struct {
|
type peekStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
insertPeekStmt *sql.Stmt
|
insertPeekStmt *sql.Stmt
|
||||||
deletePeekStmt *sql.Stmt
|
deletePeekStmt *sql.Stmt
|
||||||
deletePeeksStmt *sql.Stmt
|
deletePeeksStmt *sql.Stmt
|
||||||
@ -75,7 +75,7 @@ type peekStatements struct {
|
|||||||
selectMaxPeekIDStmt *sql.Stmt
|
selectMaxPeekIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
|
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
|
||||||
_, err := db.Exec(peeksSchema)
|
_, err := db.Exec(peeksSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
|
|||||||
|
|
||||||
type presenceStatements struct {
|
type presenceStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
upsertPresenceStmt *sql.Stmt
|
upsertPresenceStmt *sql.Stmt
|
||||||
upsertPresenceFromSyncStmt *sql.Stmt
|
upsertPresenceFromSyncStmt *sql.Stmt
|
||||||
selectPresenceForUsersStmt *sql.Stmt
|
selectPresenceForUsersStmt *sql.Stmt
|
||||||
@ -83,7 +83,7 @@ type presenceStatements struct {
|
|||||||
selectPresenceAfterStmt *sql.Stmt
|
selectPresenceAfterStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
|
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
|
||||||
_, err := db.Exec(presenceSchema)
|
_, err := db.Exec(presenceSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
|
|||||||
|
|
||||||
type receiptStatements struct {
|
type receiptStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
upsertReceipt *sql.Stmt
|
upsertReceipt *sql.Stmt
|
||||||
selectRoomReceipts *sql.Stmt
|
selectRoomReceipts *sql.Stmt
|
||||||
selectMaxReceiptID *sql.Stmt
|
selectMaxReceiptID *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
|
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
|
||||||
_, err := db.Exec(receiptsSchema)
|
_, err := db.Exec(receiptsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
|
|||||||
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
|
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
|
||||||
" RETURNING stream_id"
|
" RETURNING stream_id"
|
||||||
|
|
||||||
type streamIDStatements struct {
|
type StreamIDStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
increaseStreamIDStmt *sql.Stmt
|
increaseStreamIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
_, err = db.Exec(streamIDTableSchema)
|
_, err = db.Exec(streamIDTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
|
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
|
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
|
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
|
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
|
||||||
return
|
return
|
||||||
|
@ -30,7 +30,7 @@ type SyncServerDatasource struct {
|
|||||||
shared.Database
|
shared.Database
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.Writer
|
writer sqlutil.Writer
|
||||||
streamID streamIDStatements
|
streamID StreamIDStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new sync server database
|
// NewDatabase creates a new sync server database
|
||||||
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
|
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
|
||||||
if err = d.streamID.prepare(d.db); err != nil {
|
if err = d.streamID.Prepare(d.db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
|
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
|
||||||
|
@ -3,6 +3,7 @@ package storage_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("WriteEvent failed: %s", err)
|
t.Fatalf("WriteEvent failed: %s", err)
|
||||||
}
|
}
|
||||||
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
|
t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
|
||||||
positions = append(positions, pos)
|
positions = append(positions, pos)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||||||
|
|
||||||
func TestWriteEvents(t *testing.T) {
|
func TestWriteEvents(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
t.Parallel()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
r := test.NewRoom(t, alice)
|
r := test.NewRoom(t, alice)
|
||||||
db, close := MustCreateDatabase(t, dbType)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) {
|
|||||||
db, close := MustCreateDatabase(t, dbType)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
var filter gomatrixserverlib.RoomEventFilter
|
// dummy room to make sure SQL queries are filtering on room ID
|
||||||
filter.Limit = 100
|
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||||
|
|
||||||
|
// actual test room
|
||||||
r := test.NewRoom(t, alice)
|
r := test.NewRoom(t, alice)
|
||||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
|
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
|
||||||
events := r.Events()
|
events := r.Events()
|
||||||
positions := MustWriteEvents(t, db, events)
|
positions := MustWriteEvents(t, db, events)
|
||||||
|
|
||||||
|
// dummy room to make sure SQL queries are filtering on room ID
|
||||||
|
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||||
|
|
||||||
latest, err := db.MaxStreamPositionForPDUs(ctx)
|
latest, err := db.MaxStreamPositionForPDUs(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
|
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
Name string
|
Name string
|
||||||
From types.StreamPosition
|
From types.StreamPosition
|
||||||
To types.StreamPosition
|
To types.StreamPosition
|
||||||
WantEvents []*gomatrixserverlib.HeaderedEvent
|
Limit int
|
||||||
WantLimited bool
|
ReverseOrder bool
|
||||||
|
WantEvents []*gomatrixserverlib.HeaderedEvent
|
||||||
|
WantLimited bool
|
||||||
}{
|
}{
|
||||||
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
|
||||||
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
|
// It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
|
||||||
// It makes sure the response includes the final event.
|
// It makes sure the response includes the final event.
|
||||||
{
|
{
|
||||||
Name: "IncrementalSync penultimate",
|
Name: "penultimate",
|
||||||
From: positions[len(positions)-2], // pretend we are at the penultimate event
|
From: positions[len(positions)-2], // pretend we are at the penultimate event
|
||||||
To: latest,
|
To: latest,
|
||||||
|
Limit: 100,
|
||||||
WantEvents: events[len(events)-1:],
|
WantEvents: events[len(events)-1:],
|
||||||
WantLimited: false,
|
WantLimited: false,
|
||||||
},
|
},
|
||||||
/*
|
// The purpose of this test is to check that limits can be applied and work.
|
||||||
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
|
// This is critical for big rooms hence the test here.
|
||||||
// number of returned events. This is critical for big rooms hence the test here.
|
{
|
||||||
{
|
Name: "limited",
|
||||||
Name: "IncrementalSync limited",
|
From: 0,
|
||||||
DoSync: func() (*types.Response, error) {
|
To: latest,
|
||||||
from := types.StreamingToken{ // pretend we are 10 events behind
|
Limit: 1,
|
||||||
PDUPosition: positions[len(positions)-11],
|
WantEvents: events[len(events)-1:],
|
||||||
}
|
WantLimited: true,
|
||||||
res := types.NewResponse()
|
},
|
||||||
// limit is set to 5
|
// The purpose of this test is to check that we can return every event with a high
|
||||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
// enough limit
|
||||||
},
|
{
|
||||||
// want the last 5 events, NOT the last 10.
|
Name: "large limited",
|
||||||
WantTimeline: events[len(events)-5:],
|
From: 0,
|
||||||
},
|
To: latest,
|
||||||
// The purpose of this test is to check that CompleteSync returns all the current state as well as
|
Limit: 100,
|
||||||
// honouring the `numRecentEventsPerRoom` value
|
WantEvents: events,
|
||||||
{
|
WantLimited: false,
|
||||||
Name: "CompleteSync limited",
|
},
|
||||||
DoSync: func() (*types.Response, error) {
|
// The purpose of this test is to check that we can return events in reverse order
|
||||||
res := types.NewResponse()
|
{
|
||||||
// limit set to 5
|
Name: "reverse",
|
||||||
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
|
From: positions[len(positions)-3], // 2 events back
|
||||||
},
|
To: latest,
|
||||||
// want the last 5 events
|
Limit: 100,
|
||||||
WantTimeline: events[len(events)-5:],
|
ReverseOrder: true,
|
||||||
// want all state for the room
|
WantEvents: test.Reversed(events[len(events)-2:]),
|
||||||
WantState: state,
|
WantLimited: false,
|
||||||
},
|
},
|
||||||
// The purpose of this test is to check that CompleteSync can return everything with a high enough
|
|
||||||
// `numRecentEventsPerRoom`.
|
|
||||||
{
|
|
||||||
Name: "CompleteSync",
|
|
||||||
DoSync: func() (*types.Response, error) {
|
|
||||||
res := types.NewResponse()
|
|
||||||
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
|
|
||||||
},
|
|
||||||
WantTimeline: events,
|
|
||||||
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
|
|
||||||
// and the START of the timeline.
|
|
||||||
}, */
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for i := range testCases {
|
||||||
|
tc := testCases[i]
|
||||||
t.Run(tc.Name, func(st *testing.T) {
|
t.Run(tc.Name, func(st *testing.T) {
|
||||||
|
var filter gomatrixserverlib.RoomEventFilter
|
||||||
|
filter.Limit = tc.Limit
|
||||||
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
|
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
|
||||||
From: tc.From,
|
From: tc.From,
|
||||||
To: tc.To,
|
To: tc.To,
|
||||||
}, &filter, true, true)
|
}, &filter, !tc.ReverseOrder, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
st.Fatalf("failed to do sync: %s", err)
|
st.Fatalf("failed to do sync: %s", err)
|
||||||
}
|
}
|
||||||
@ -148,100 +148,48 @@ func TestRecentEventsPDU(t *testing.T) {
|
|||||||
if len(gotEvents) != len(tc.WantEvents) {
|
if len(gotEvents) != len(tc.WantEvents) {
|
||||||
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
|
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
|
||||||
}
|
}
|
||||||
|
for j := range gotEvents {
|
||||||
|
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
|
||||||
|
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
db := MustCreateDatabase(t)
|
|
||||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
|
||||||
positions := MustWriteEvents(t, db, events)
|
|
||||||
latest, err := db.SyncPosition(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
|
||||||
}
|
|
||||||
from := types.StreamingToken{
|
|
||||||
PDUPosition: positions[len(positions)-2],
|
|
||||||
}
|
|
||||||
|
|
||||||
res := types.NewResponse()
|
|
||||||
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to IncrementalSync with latest token")
|
|
||||||
}
|
|
||||||
roomRes, ok := res.Rooms.Join[testRoomID]
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
|
|
||||||
}
|
|
||||||
// returns the last event "Message 10"
|
|
||||||
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
|
|
||||||
|
|
||||||
prev := roomRes.Timeline.PrevBatch.String()
|
|
||||||
if prev == "" {
|
|
||||||
t.Fatalf("IncrementalSync expected prev_batch token")
|
|
||||||
}
|
|
||||||
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
|
|
||||||
}
|
|
||||||
// backpaginate 5 messages starting at the latest position.
|
|
||||||
// head towards the beginning of time
|
|
||||||
to := types.TopologyToken{}
|
|
||||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
|
||||||
}
|
|
||||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
|
||||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
|
|
||||||
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
db := MustCreateDatabase(t)
|
|
||||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
|
||||||
MustWriteEvents(t, db, events)
|
|
||||||
latest, err := db.SyncPosition(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
|
||||||
}
|
|
||||||
// head towards the beginning of time
|
|
||||||
to := types.StreamingToken{}
|
|
||||||
|
|
||||||
// backpaginate 5 messages starting at the latest position.
|
|
||||||
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
|
||||||
}
|
|
||||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
|
||||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
|
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
|
||||||
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||||
t.Parallel()
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db := MustCreateDatabase(t)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
|
defer close()
|
||||||
MustWriteEvents(t, db, events)
|
alice := test.NewUser()
|
||||||
from, err := db.MaxTopologicalPosition(ctx, testRoomID)
|
r := test.NewRoom(t, alice)
|
||||||
if err != nil {
|
for i := 0; i < 10; i++ {
|
||||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||||
}
|
}
|
||||||
// head towards the beginning of time
|
events := r.Events()
|
||||||
to := types.TopologyToken{}
|
_ = MustWriteEvents(t, db, events)
|
||||||
|
|
||||||
// backpaginate 5 messages starting at the latest position.
|
from, err := db.MaxTopologicalPosition(ctx, r.ID)
|
||||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
|
if err != nil {
|
||||||
if err != nil {
|
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||||
t.Fatalf("GetEventsInRange returned an error: %s", err)
|
}
|
||||||
}
|
t.Logf("max topo pos = %+v", from)
|
||||||
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
|
// head towards the beginning of time
|
||||||
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
|
to := types.TopologyToken{}
|
||||||
|
|
||||||
|
// backpaginate 5 messages starting at the latest position.
|
||||||
|
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||||
|
}
|
||||||
|
gots := db.StreamEventsToEvents(nil, paginatedEvents)
|
||||||
|
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
|
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
|
||||||
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
|
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
|
||||||
// will appear FIRST when going backwards. This test creates a DAG like:
|
// will appear FIRST when going backwards. This test creates a DAG like:
|
||||||
@ -651,12 +599,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
|
|||||||
tok.Decrement()
|
tok.Decrement()
|
||||||
return &tok
|
return &tok
|
||||||
}
|
}
|
||||||
|
|
||||||
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
|
||||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
|
||||||
for i := 0; i < len(in); i++ {
|
|
||||||
out[i] = in[len(in)-i-1]
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
*/
|
*/
|
||||||
|
@ -59,7 +59,7 @@ type Events interface {
|
|||||||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||||
// SelectEarlyEvents returns the earliest events in the given room.
|
// SelectEarlyEvents returns the earliest events in the given room.
|
||||||
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
|
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
|
||||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
|
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error)
|
||||||
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
|
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
|
||||||
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
|
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
|
||||||
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||||
|
82
syncapi/storage/tables/output_room_events_test.go
Normal file
82
syncapi/storage/tables/output_room_events_test.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open db: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tab tables.Events
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresEventsTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
var stream sqlite3.StreamIDStatements
|
||||||
|
if err = stream.Prepare(db); err != nil {
|
||||||
|
t.Fatalf("failed to prepare stream stmts: %s", err)
|
||||||
|
}
|
||||||
|
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to make new table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, db, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutputRoomEventsTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser()
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
events := room.Events()
|
||||||
|
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||||
|
for _, ev := range events {
|
||||||
|
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// order = 2,0,3,1
|
||||||
|
wantEventIDs := []string{
|
||||||
|
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
|
||||||
|
}
|
||||||
|
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||||
|
}
|
||||||
|
gotEventIDs := make([]string, len(gotEvents))
|
||||||
|
for i := range gotEvents {
|
||||||
|
gotEventIDs[i] = gotEvents[i].EventID()
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
|
||||||
|
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
91
syncapi/storage/tables/topology_test.go
Normal file
91
syncapi/storage/tables/topology_test.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open db: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tab tables.Topology
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresTopologyTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSqliteTopologyTable(db)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to make new table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, db, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTopologyTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser()
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, db, close := newTopologyTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
events := room.Events()
|
||||||
|
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||||
|
var highestPos types.StreamPosition
|
||||||
|
for i, ev := range events {
|
||||||
|
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
|
||||||
|
}
|
||||||
|
// topo pos = depth, depth starts at 1, hence 1+i
|
||||||
|
if topoPos != types.StreamPosition(1+i) {
|
||||||
|
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
|
||||||
|
}
|
||||||
|
highestPos = topoPos + 1
|
||||||
|
}
|
||||||
|
// check ordering works without limit
|
||||||
|
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||||
|
}
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, events[:])
|
||||||
|
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||||
|
}
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
|
||||||
|
// check ordering works with limit
|
||||||
|
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||||
|
}
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, events[:3])
|
||||||
|
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
|
||||||
|
}
|
||||||
|
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@ -121,6 +121,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
|
|||||||
for dbName, dbType := range dbs {
|
for dbName, dbType := range dbs {
|
||||||
dbt := dbType
|
dbt := dbType
|
||||||
t.Run(dbName, func(tt *testing.T) {
|
t.Run(dbName, func(tt *testing.T) {
|
||||||
|
tt.Parallel()
|
||||||
testFn(tt, dbt)
|
testFn(tt, dbt)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,9 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
|
|||||||
e.unsigned = unsigned
|
e.unsigned = unsigned
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reverse a list of events
|
||||||
|
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||||
|
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||||
|
for i := 0; i < len(in); i++ {
|
||||||
|
out[i] = in[len(in)-i-1]
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||||
|
t.Helper()
|
||||||
|
if len(gotEventIDs) != len(wants) {
|
||||||
|
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
|
||||||
|
}
|
||||||
|
for i := range wants {
|
||||||
|
w := wants[i].EventID()
|
||||||
|
g := gotEventIDs[i]
|
||||||
|
if w != g {
|
||||||
|
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
|
||||||
|
t.Helper()
|
||||||
|
if len(gots) != len(wants) {
|
||||||
|
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
|
||||||
|
}
|
||||||
|
for i := range wants {
|
||||||
|
w := wants[i].JSON()
|
||||||
|
g := gots[i].JSON()
|
||||||
|
if !bytes.Equal(w, g) {
|
||||||
|
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user