mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-22 11:41:38 +00:00
Fix room summary returning wrong heroes (#2930)
This should fix #2910. Probably makes Sytest/Complement a bit upset, since this not using `sort.Strings` anymore.
This commit is contained in:
parent
25dfbc6ec3
commit
0491a8e343
@ -45,7 +45,7 @@ type DatabaseTransaction interface {
|
||||
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
|
||||
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
|
||||
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
|
||||
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
|
||||
GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
|
||||
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
|
||||
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" +
|
||||
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||
") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');"
|
||||
|
||||
const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
|
||||
|
||||
const selectRoomHeroes = `
|
||||
SELECT state_key FROM syncapi_current_room_state
|
||||
WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3
|
||||
ORDER BY added_at, state_key
|
||||
LIMIT 5
|
||||
`
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
@ -122,6 +132,8 @@ type currentRoomStateStatements struct {
|
||||
selectEventsWithEventIDsStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
selectSharedUsersStmt *sql.Stmt
|
||||
selectMembershipCountStmt *sql.Stmt
|
||||
selectRoomHeroesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||
@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertRoomStateStmt, upsertRoomStateSQL},
|
||||
{&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
|
||||
{&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
|
||||
{&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
|
||||
{&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
|
||||
{&s.selectCurrentStateStmt, selectCurrentStateSQL},
|
||||
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
|
||||
{&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL},
|
||||
{&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL},
|
||||
{&s.selectStateEventStmt, selectStateEventSQL},
|
||||
{&s.selectSharedUsersStmt, selectSharedUsersSQL},
|
||||
{&s.selectMembershipCountStmt, selectMembershipCount},
|
||||
{&s.selectRoomHeroesStmt, selectRoomHeroes},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
||||
@ -447,3 +440,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed")
|
||||
|
||||
var stateKey string
|
||||
result := make([]string, 0, 5)
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&stateKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, stateKey)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
@ -19,10 +19,8 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
|
||||
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
|
||||
") t WHERE t.membership = $3"
|
||||
|
||||
const selectHeroesSQL = "" +
|
||||
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
|
||||
|
||||
const selectMembershipBeforeSQL = "" +
|
||||
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
|
||||
|
||||
@ -81,7 +76,6 @@ WHERE ($3::text IS NULL OR t.membership = $3)
|
||||
type membershipsStatements struct {
|
||||
upsertMembershipStmt *sql.Stmt
|
||||
selectMembershipCountStmt *sql.Stmt
|
||||
selectHeroesStmt *sql.Stmt
|
||||
selectMembershipForUserStmt *sql.Stmt
|
||||
selectMembersStmt *sql.Stmt
|
||||
}
|
||||
@ -95,7 +89,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertMembershipStmt, upsertMembershipSQL},
|
||||
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
||||
{&s.selectHeroesStmt, selectHeroesSQL},
|
||||
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
|
||||
{&s.selectMembersStmt, selectMembersSQL},
|
||||
}.Prepare(db)
|
||||
@ -129,26 +122,6 @@ func (s *membershipsStatements) SelectMembershipCount(
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectHeroes(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||
) (heroes []string, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt)
|
||||
var rows *sql.Rows
|
||||
rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
|
||||
var hero string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&hero); err != nil {
|
||||
return
|
||||
}
|
||||
heroes = append(heroes, hero)
|
||||
}
|
||||
return heroes, rows.Err()
|
||||
}
|
||||
|
||||
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
||||
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
||||
// string as the membership.
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
@ -92,8 +93,61 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe
|
||||
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
|
||||
return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
|
||||
func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) {
|
||||
summary := &types.Summary{Heroes: []string{}}
|
||||
|
||||
joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join)
|
||||
if err != nil {
|
||||
return summary, err
|
||||
}
|
||||
inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite)
|
||||
if err != nil {
|
||||
return summary, err
|
||||
}
|
||||
summary.InvitedMemberCount = &inviteCount
|
||||
summary.JoinedMemberCount = &joinCount
|
||||
|
||||
// Get the room name and canonical alias, if any
|
||||
filter := gomatrixserverlib.DefaultStateFilter()
|
||||
filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias}
|
||||
filterRooms := []string{roomID}
|
||||
|
||||
filter.Types = &filterTypes
|
||||
filter.Rooms = &filterRooms
|
||||
evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil)
|
||||
if err != nil {
|
||||
return summary, err
|
||||
}
|
||||
|
||||
for _, ev := range evs {
|
||||
switch ev.Type() {
|
||||
case gomatrixserverlib.MRoomName:
|
||||
if gjson.GetBytes(ev.Content(), "name").Str != "" {
|
||||
return summary, nil
|
||||
}
|
||||
case gomatrixserverlib.MRoomCanonicalAlias:
|
||||
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
|
||||
return summary, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there's no room name or canonical alias, get the room heroes, excluding the user
|
||||
heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite})
|
||||
if err != nil {
|
||||
return summary, err
|
||||
}
|
||||
|
||||
// "When no joined or invited members are available, this should consist of the banned and left users"
|
||||
if len(heroes) == 0 {
|
||||
heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban})
|
||||
if err != nil {
|
||||
return summary, err
|
||||
}
|
||||
}
|
||||
summary.Heroes = heroes
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" +
|
||||
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||
") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');"
|
||||
|
||||
const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
|
||||
|
||||
const selectRoomHeroes = `
|
||||
SELECT state_key FROM syncapi_current_room_state
|
||||
WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3)
|
||||
ORDER BY added_at, state_key
|
||||
LIMIT 5
|
||||
`
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *StreamIDStatements
|
||||
@ -107,6 +117,8 @@ type currentRoomStateStatements struct {
|
||||
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
selectStateEventStmt *sql.Stmt
|
||||
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
|
||||
selectMembershipCountStmt *sql.Stmt
|
||||
//selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||
@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertRoomStateStmt, upsertRoomStateSQL},
|
||||
{&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
|
||||
{&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
|
||||
{&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
|
||||
{&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
|
||||
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
|
||||
{&s.selectStateEventStmt, selectStateEventSQL},
|
||||
{&s.selectMembershipCountStmt, selectMembershipCount},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
||||
@ -485,3 +482,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
|
||||
params := make([]interface{}, len(memberships)+2)
|
||||
params[0] = roomID
|
||||
params[1] = excludeUserID
|
||||
for k, v := range memberships {
|
||||
params[k+2] = v
|
||||
}
|
||||
|
||||
query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||
var stmt *sql.Stmt
|
||||
var err error
|
||||
if txn != nil {
|
||||
stmt, err = txn.Prepare(query)
|
||||
} else {
|
||||
stmt, err = s.db.Prepare(query)
|
||||
}
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed")
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed")
|
||||
|
||||
var stateKey string
|
||||
result := make([]string, 0, 5)
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&stateKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, stateKey)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
@ -18,11 +18,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
|
||||
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
|
||||
") t WHERE t.membership = $3"
|
||||
|
||||
const selectHeroesSQL = "" +
|
||||
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
|
||||
|
||||
const selectMembershipBeforeSQL = "" +
|
||||
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
|
||||
|
||||
@ -99,7 +94,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
|
||||
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
|
||||
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
|
||||
{&s.selectMembersStmt, selectMembersSQL},
|
||||
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
@ -131,39 +125,6 @@ func (s *membershipsStatements) SelectMembershipCount(
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipsStatements) SelectHeroes(
|
||||
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||
) (heroes []string, err error) {
|
||||
stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
|
||||
stmt, err := s.db.PrepareContext(ctx, stmtSQL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed")
|
||||
params := []interface{}{
|
||||
roomID, userID,
|
||||
}
|
||||
for _, membership := range memberships {
|
||||
params = append(params, membership)
|
||||
}
|
||||
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
var rows *sql.Rows
|
||||
rows, err = stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
|
||||
var hero string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&hero); err != nil {
|
||||
return
|
||||
}
|
||||
heroes = append(heroes, hero)
|
||||
}
|
||||
return heroes, rows.Err()
|
||||
}
|
||||
|
||||
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
|
||||
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
|
||||
// string as the membership.
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
@ -664,3 +665,181 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
|
||||
return &tok
|
||||
}
|
||||
*/
|
||||
|
||||
func pointer[t any](s t) *t {
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestRoomSummary(t *testing.T) {
|
||||
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
charlie := test.NewUser(t)
|
||||
|
||||
// Create some dummy users
|
||||
moreUsers := []*test.User{}
|
||||
moreUserIDs := []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
u := test.NewUser(t)
|
||||
moreUsers = append(moreUsers, u)
|
||||
moreUserIDs = append(moreUserIDs, u.ID)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantSummary *types.Summary
|
||||
additionalEvents func(t *testing.T, room *test.Room)
|
||||
}{
|
||||
{
|
||||
name: "after initial creation",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}},
|
||||
},
|
||||
{
|
||||
name: "invited user",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invited user, but declined",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "joined user after invitation",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple joined user",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(charlie.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple joined/invited user",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(charlie.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple joined/invited/left user",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(charlie.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "leaving user after joining",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "many users", // heroes ordered by stream id
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
for _, x := range moreUsers {
|
||||
room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(x.ID))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "canonical alias set",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{
|
||||
"alias": "myalias",
|
||||
}, test.WithStateKey(""))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "room name set",
|
||||
wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{
|
||||
"name": "my room name",
|
||||
}, test.WithStateKey(""))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close, closeBase := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
defer closeBase()
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
r := test.NewRoom(t, alice)
|
||||
|
||||
if tc.additionalEvents != nil {
|
||||
tc.additionalEvents(t, r)
|
||||
}
|
||||
|
||||
// write the room before creating a transaction
|
||||
MustWriteEvents(t, db, r.Events())
|
||||
|
||||
transaction, err := db.NewDatabaseTransaction(ctx)
|
||||
assert.NoError(t, err)
|
||||
defer transaction.Rollback()
|
||||
|
||||
summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.wantSummary, summary)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -115,6 +115,9 @@ type CurrentRoomState interface {
|
||||
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
|
||||
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
|
||||
|
||||
SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error)
|
||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error)
|
||||
}
|
||||
|
||||
// BackwardsExtremities keeps track of backwards extremities for a room.
|
||||
@ -185,7 +188,6 @@ type Receipts interface {
|
||||
type Memberships interface {
|
||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
|
||||
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
|
||||
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
|
||||
SelectMemberships(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
|
@ -3,8 +3,6 @@ package tables_test
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) {
|
||||
|
||||
testUpsert(t, ctx, table, userEvents[0], alice, room)
|
||||
testMembershipCount(t, ctx, table, room)
|
||||
testHeroes(t, ctx, table, alice, room, users)
|
||||
})
|
||||
}
|
||||
|
||||
func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
|
||||
|
||||
// Re-slice and sort the expected users
|
||||
users = users[1:]
|
||||
sort.Strings(users)
|
||||
type testCase struct {
|
||||
name string
|
||||
memberships []string
|
||||
wantHeroes []string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{name: "no memberships queried", memberships: []string{}},
|
||||
{name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to select heroes: %s", err)
|
||||
}
|
||||
if gotLen := len(got); gotLen != len(tc.wantHeroes) {
|
||||
t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tc.wantHeroes) {
|
||||
t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
|
||||
t.Run("membership counts are correct", func(t *testing.T) {
|
||||
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
@ -14,11 +13,9 @@ import (
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/notifier"
|
||||
)
|
||||
|
||||
// The max number of per-room goroutines to have running.
|
||||
@ -339,7 +336,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||
case gomatrixserverlib.Join:
|
||||
jr := types.NewJoinResponse()
|
||||
if hasMembershipChange {
|
||||
p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition)
|
||||
jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("failed to get room summary")
|
||||
}
|
||||
}
|
||||
jr.Timeline.PrevBatch = &prevBatch
|
||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
|
||||
@ -411,45 +411,6 @@ func applyHistoryVisibilityFilter(
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
|
||||
// Work out how many members are in the room.
|
||||
joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
|
||||
invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
|
||||
|
||||
jr.Summary.JoinedMemberCount = &joinedCount
|
||||
jr.Summary.InvitedMemberCount = &invitedCount
|
||||
|
||||
fetchStates := []gomatrixserverlib.StateKeyTuple{
|
||||
{EventType: gomatrixserverlib.MRoomName},
|
||||
{EventType: gomatrixserverlib.MRoomCanonicalAlias},
|
||||
}
|
||||
// Check if the room has a name or a canonical alias
|
||||
latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{}
|
||||
err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Check if the room has a name or canonical alias, if so, return.
|
||||
for _, ev := range latestState.StateEvents {
|
||||
switch ev.Type() {
|
||||
case gomatrixserverlib.MRoomName:
|
||||
if gjson.GetBytes(ev.Content(), "name").Str != "" {
|
||||
return
|
||||
}
|
||||
case gomatrixserverlib.MRoomCanonicalAlias:
|
||||
if gjson.GetBytes(ev.Content(), "alias").Str != "" {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sort.Strings(heroes)
|
||||
jr.Summary.Heroes = heroes
|
||||
}
|
||||
|
||||
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||
ctx context.Context,
|
||||
snapshot storage.DatabaseTransaction,
|
||||
@ -493,7 +454,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
||||
return
|
||||
}
|
||||
|
||||
p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From)
|
||||
jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("failed to get room summary")
|
||||
}
|
||||
|
||||
// We don't include a device here as we don't need to send down
|
||||
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
|
||||
|
Loading…
Reference in New Issue
Block a user