Don't get blacklisted hosts when querying joined servers (#2880)

Otherwise we just waste time/CPU.
This commit is contained in:
Neil Alexander 2022-11-15 17:21:16 +00:00 committed by GitHub
parent 5c9aed6af9
commit 9b8bb55430
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 32 deletions

View File

@ -200,6 +200,7 @@ type PerformInviteResponse struct {
type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"` ExcludeSelf bool `json:"exclude_self"`
ExcludeBlacklisted bool `json:"exclude_blacklisted"`
} }
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames

View File

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return true return true
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View File

@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View File

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) { ) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted)
if err != nil { if err != nil {
return return
} }

View File

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View File

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt
selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt
} }
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return return
} }
if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil {
return
}
return return
} }
@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) stmt := s.selectJoinedHostsForRoomsStmt
if excludingBlacklisted {
stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt
}
rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewPostgresJoinedHostsTable(d.db) joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -117,8 +117,8 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
const selectJoinedHostsForRoomsSQL = "" + const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" +
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
");"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
// selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string, excludingBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs { for i := range roomIDs {
iRoomIDs[i] = roomIDs[i] iRoomIDs[i] = roomIDs[i]
} }
query := selectJoinedHostsForRoomsSQL
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) if excludingBlacklisted {
query = selectJoinedHostsForRoomsExcludingBlacklistedSQL
}
sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -58,7 +58,7 @@ type FederationJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
} }
type FederationBlacklist interface { type FederationBlacklist interface {

View File

@ -167,6 +167,7 @@ func (r *Inputer) processRoomEvent(
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
ExcludeSelf: true, ExcludeSelf: true,
ExcludeBlacklisted: true,
} }
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)