bugfix: fix sytest 155 by actually returning depth+1 and not 0

This commit is contained in:
Kegan Dougal 2020-03-06 14:31:12 +00:00
parent a97b8eafd4
commit 87283e9de7
2 changed files with 10 additions and 4 deletions

View File

@ -111,7 +111,6 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
}
func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -135,7 +134,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
}.prepare(db)
}
@ -462,8 +460,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) {
var result int64
selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt)
err := selectStmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
iEventIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs {
iEventIDs[i] = v
}
sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result)
if err != nil {
return 0, err
}

View File

@ -16,6 +16,7 @@ package routing
import (
"context"
"fmt"
"net/http"
"sort"
"strconv"
@ -176,6 +177,7 @@ func (r *messagesReq) retrieveEvents() (
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %s", err)
return
}
@ -226,12 +228,14 @@ func (r *messagesReq) retrieveEvents() (
r.ctx, events[0].EventID(),
)
if err != nil {
err = fmt.Errorf("EventPositionInTopology: for start event %s: %s", events[0].EventID(), err)
return
}
endPos, err := r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
if err != nil {
err = fmt.Errorf("EventPositionInTopology: for end event %s: %s", events[len(events)-1].EventID(), err)
return
}
// Generate pagination tokens to send to the client using the positions