Optimise QuerySharedUsers so that we can only work on local users (#2766)

Otherwise the sync API key change consumer wastes a lot of time trying
to wake up the notifiers for non-local users.
This commit is contained in:
Neil Alexander 2022-10-05 12:47:53 +01:00 committed by GitHub
parent 6f602bb096
commit c85bc3434f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 36 additions and 23 deletions

View File

@ -278,6 +278,7 @@ type QuerySharedUsersRequest struct {
OtherUserIDs []string OtherUserIDs []string
ExcludeRoomIDs []string ExcludeRoomIDs []string
IncludeRoomIDs []string IncludeRoomIDs []string
LocalOnly bool
} }
type QuerySharedUsersResponse struct { type QuerySharedUsersResponse struct {

View File

@ -799,7 +799,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
} }
roomIDs = roomIDs[:j] roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs) users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs, req.LocalOnly)
if err != nil { if err != nil {
return err return err
} }

View File

@ -157,7 +157,7 @@ type Database interface {
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms. // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View File

@ -68,14 +68,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
var selectJoinedUsersSetForRoomsAndUserSQL = "" + var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " WHERE (target_local OR $1 = false)" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid = ANY($2) AND target_nid = ANY($3)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND" + " WHERE (target_local OR $1 = false) " +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid = ANY($2)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
@ -334,6 +338,7 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID, roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID, userNIDs []types.EventStateKeyNID,
localOnly bool,
) (map[types.EventStateKeyNID]int, error) { ) (map[types.EventStateKeyNID]int, error) {
var ( var (
rows *sql.Rows rows *sql.Rows
@ -342,9 +347,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
if len(userNIDs) > 0 { if len(userNIDs) > 0 {
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt) stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs), pq.Array(userNIDs))
} else { } else {
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs)) rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs))
} }
if err != nil { if err != nil {

View File

@ -1280,7 +1280,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms. // JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1295,7 +1295,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
userNIDs = append(userNIDs, nid) userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id nidToUserID[nid] = id
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs) userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs, localOnly)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -44,14 +44,18 @@ const membershipSchema = `
var selectJoinedUsersSetForRoomsAndUserSQL = "" + var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " WHERE (target_local OR $1 = false)" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid IN ($2) AND target_nid IN ($3)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND " + " WHERE (target_local OR $1 = false)" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " AND room_nid IN ($2)" +
" AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
" AND forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
@ -305,8 +309,9 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) {
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs)) params := make([]interface{}, 0, 1+len(roomNIDs)+len(userNIDs))
params = append(params, localOnly)
for _, v := range roomNIDs { for _, v := range roomNIDs {
params = append(params, v) params = append(params, v)
} }
@ -314,10 +319,10 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
params = append(params, v) params = append(params, v)
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
if len(userNIDs) > 0 { if len(userNIDs) > 0 {
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)+1), 1)
} }
var rows *sql.Rows var rows *sql.Rows
var err error var err error

View File

@ -137,7 +137,7 @@ type Membership interface {
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)

View File

@ -79,7 +79,7 @@ func TestMembershipTable(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, inRoom) assert.True(t, inRoom)
userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs) userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs, false)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(userJoinedToRooms)) assert.Equal(t, 1, len(userJoinedToRooms))

View File

@ -111,7 +111,8 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
// work out who we need to notify about the new key // work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{ err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: output.UserID, UserID: output.UserID,
LocalOnly: true,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
@ -135,7 +136,8 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
// work out who we need to notify about the new key // work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{ err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: output.UserID, UserID: output.UserID,
LocalOnly: true,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")