Fix prev_batch tokens (#999)

This commit is contained in:
Kegsay 2020-05-01 12:41:38 +01:00 committed by GitHub
parent b28674435e
commit 17e046f18f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 83 additions and 75 deletions

View File

@ -94,9 +94,6 @@ const selectEarlyEventsSQL = "" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " ORDER BY id ASC LIMIT $4"
const selectStreamPositionForEventIDSQL = "" +
"SELECT id FROM syncapi_output_room_events WHERE event_id = $1"
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
@ -114,14 +111,13 @@ const selectStateInRangeSQL = "" +
" LIMIT $8" " LIMIT $8"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
selectStreamPositionForEventIDStmt *sql.Stmt
} }
func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
@ -150,18 +146,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return return
} }
if s.selectStreamPositionForEventIDStmt, err = db.Prepare(selectStreamPositionForEventIDSQL); err != nil {
return
}
return return
} }
func (s *outputRoomEventsStatements) selectStreamPositionForEventID(ctx context.Context, eventID string) (types.StreamPosition, error) {
var id int64
err := s.selectStreamPositionForEventIDStmt.QueryRowContext(ctx, eventID).Scan(&id)
return types.StreamPosition(id), err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.

View File

@ -60,7 +60,7 @@ const selectEventIDsInRangeDESCSQL = "" +
" ORDER BY topological_position DESC, stream_position DESC LIMIT $6" " ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
const selectPositionInTopologySQL = "" + const selectPositionInTopologySQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1" " WHERE event_id = $1"
// Select the max topological position for the room, then sort by stream position and take the highest, // Select the max topological position for the room, then sort by stream position and take the highest,
@ -163,8 +163,8 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
// topology of the room it belongs to. // topology of the room it belongs to.
func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (pos types.StreamPosition, err error) { ) (pos, spos types.StreamPosition, err error) {
err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos)
return return
} }

View File

@ -320,12 +320,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition(
func (d *SyncServerDatasource) EventPositionInTopology( func (d *SyncServerDatasource) EventPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (depth types.StreamPosition, stream types.StreamPosition, err error) { ) (depth types.StreamPosition, stream types.StreamPosition, err error) {
depth, err = d.topology.selectPositionInTopology(ctx, eventID) return d.topology.selectPositionInTopology(ctx, eventID)
if err != nil {
return
}
stream, err = d.events.selectStreamPositionForEventID(ctx, eventID)
return
} }
func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
@ -591,8 +586,8 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// Retrieve the backward topology position, i.e. the position of the // Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
var backwardTopologyPos types.StreamPosition var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID())
if backwardTopologyPos-1 <= 0 { if backwardTopologyPos-1 <= 0 {
backwardTopologyPos = types.StreamPosition(1) backwardTopologyPos = types.StreamPosition(1)
} else { } else {
@ -605,7 +600,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0, types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String() ).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
@ -720,9 +715,9 @@ func (d *SyncServerDatasource) addInvitesToResponse(
func (d *SyncServerDatasource) getBackwardTopologyPos( func (d *SyncServerDatasource) getBackwardTopologyPos(
ctx context.Context, ctx context.Context,
events []types.StreamEvent, events []types.StreamEvent,
) (pos types.StreamPosition) { ) (pos, spos types.StreamPosition) {
if len(events) > 0 { if len(events) > 0 {
pos, _ = d.topology.selectPositionInTopology(ctx, events[0].EventID()) pos, spos, _ = d.topology.selectPositionInTopology(ctx, events[0].EventID())
} }
if pos-1 <= 0 { if pos-1 <= 0 {
pos = types.StreamPosition(1) pos = types.StreamPosition(1)
@ -761,14 +756,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
} }
recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
backwardTopologyPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, recentStreamEvents)
switch delta.membership { switch delta.membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0, types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String() ).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
@ -781,7 +776,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0, types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String() ).String()
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true

View File

@ -74,9 +74,6 @@ const selectEarlyEventsSQL = "" +
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
const selectStreamPositionForEventIDSQL = "" +
"SELECT id FROM syncapi_output_room_events WHERE event_id = $1"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
/* /*
$1 = oldPos, $1 = oldPos,
@ -102,15 +99,14 @@ const selectStateInRangeSQL = "" +
" LIMIT $8" // limit " LIMIT $8" // limit
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
selectStreamPositionForEventIDStmt *sql.Stmt
} }
func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
@ -140,18 +136,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDState
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return return
} }
if s.selectStreamPositionForEventIDStmt, err = db.Prepare(selectStreamPositionForEventIDSQL); err != nil {
return
}
return return
} }
func (s *outputRoomEventsStatements) selectStreamPositionForEventID(ctx context.Context, eventID string) (types.StreamPosition, error) {
var id int64
err := s.selectStreamPositionForEventIDStmt.QueryRowContext(ctx, eventID).Scan(&id)
return types.StreamPosition(id), err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.

View File

@ -57,7 +57,7 @@ const selectEventIDsInRangeDESCSQL = "" +
" ORDER BY topological_position DESC, stream_position DESC LIMIT $6" " ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
const selectPositionInTopologySQL = "" + const selectPositionInTopologySQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1" " WHERE event_id = $1"
const selectMaxPositionInTopologySQL = "" + const selectMaxPositionInTopologySQL = "" +
@ -157,9 +157,9 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
// topology of the room it belongs to. // topology of the room it belongs to.
func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, spos types.StreamPosition, err error) {
stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt) stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt)
err = stmt.QueryRowContext(ctx, eventID).Scan(&pos) err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos)
return return
} }

View File

@ -374,12 +374,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition(
func (d *SyncServerDatasource) EventPositionInTopology( func (d *SyncServerDatasource) EventPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (depth types.StreamPosition, stream types.StreamPosition, err error) { ) (depth types.StreamPosition, stream types.StreamPosition, err error) {
depth, err = d.topology.selectPositionInTopology(ctx, nil, eventID) return d.topology.selectPositionInTopology(ctx, nil, eventID)
if err != nil {
return
}
stream, err = d.events.selectStreamPositionForEventID(ctx, eventID)
return
} }
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
@ -657,8 +652,8 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// Retrieve the backward topology position, i.e. the position of the // Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
var backwardTopologyPos types.StreamPosition var backwardTopologyPos, backwardTopologyStreamPos types.StreamPosition
backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardTopologyStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if backwardTopologyPos-1 <= 0 { if backwardTopologyPos-1 <= 0 {
backwardTopologyPos = types.StreamPosition(1) backwardTopologyPos = types.StreamPosition(1)
} else { } else {
@ -671,7 +666,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0, types.PaginationTokenTypeTopology, backwardTopologyPos, backwardTopologyStreamPos,
).String() ).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
@ -818,10 +813,11 @@ func (d *SyncServerDatasource) addInvitesToResponse(
func (d *SyncServerDatasource) getBackwardTopologyPos( func (d *SyncServerDatasource) getBackwardTopologyPos(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
events []types.StreamEvent, events []types.StreamEvent,
) (pos types.StreamPosition) { ) (pos, spos types.StreamPosition) {
if len(events) > 0 { if len(events) > 0 {
pos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID()) pos, spos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID())
} }
// TODO: I have no idea what this is doing.
if pos-1 <= 0 { if pos-1 <= 0 {
pos = types.StreamPosition(1) pos = types.StreamPosition(1)
} else { } else {
@ -859,14 +855,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
} }
recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents)
backwardTopologyPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents)
switch delta.membership { switch delta.membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0, types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String() ).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
@ -879,7 +875,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0, types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String() ).String()
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true

View File

@ -220,6 +220,49 @@ func TestSyncResponse(t *testing.T) {
} }
} }
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
positions := MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0),
)
res, err := db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
roomRes, ok := res.Rooms.Join[testRoomID]
if !ok {
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
}
// returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch
if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token")
}
prevBatchToken, err := types.NewPaginationTokenFromString(prev)
if err != nil {
t.Fatalf("failed to NewPaginationTokenFromString : %s", err)
}
// backpaginate 5 messages starting at the latest position.
// head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0)
paginatedEvents, err := db.GetEventsInRange(ctx, prevBatchToken, to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. // The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
func TestGetEventsInRangeWithStreamToken(t *testing.T) { func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Parallel() t.Parallel()