Exclude our own server name in GetJoinedHostsForRooms (#2110)

* Exclude our own servername

* Make excluding self behaviour optional
This commit is contained in:
Neil Alexander 2022-01-25 17:00:39 +00:00 committed by GitHub
parent 49a618dfe2
commit 8a1bc70524
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 34 additions and 15 deletions

View File

@ -186,7 +186,8 @@ type PerformServersAliveResponse struct {
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
ExcludeSelf bool `json:"exclude_self"`
} }
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames

View File

@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
return nil return nil
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil { if err != nil {
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
return nil return nil
@ -180,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
return nil return nil
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil { if err != nil {
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
return nil return nil

View File

@ -78,7 +78,7 @@ func NewInternalAPI(
) api.FederationInternalAPI { ) api.FederationInternalAPI {
cfg := &base.Cfg.FederationAPI cfg := &base.Cfg.FederationAPI
federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches) federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches, base.Cfg.Global.ServerName)
if err != nil { if err != nil {
logrus.WithError(err).Panic("failed to connect to federation sender db") logrus.WithError(err).Panic("failed to connect to federation sender db")
} }

View File

@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest, request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse, response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) { ) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}) joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf)
if err != nil { if err != nil {
return return
} }

View File

@ -32,7 +32,7 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View File

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
) )
// Database stores information needed by the federation sender // Database stores information needed by the federation sender
@ -35,7 +36,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
@ -89,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
ServerName: serverName,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationJoinedHosts: joinedHosts, FederationJoinedHosts: joinedHosts,

View File

@ -29,6 +29,7 @@ import (
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
ServerName gomatrixserverlib.ServerName
Cache caching.FederationCache Cache caching.FederationCache
Writer sqlutil.Writer Writer sqlutil.Writer
FederationQueuePDUs tables.FederationQueuePDUs FederationQueuePDUs tables.FederationQueuePDUs
@ -102,8 +103,19 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
} }
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
if err != nil {
return nil, err
}
if excludeSelf {
for i, server := range servers {
if server == d.ServerName {
servers = append(servers[:i], servers[i+1:]...)
}
}
}
return servers, nil
} }
// StoreJSON adds a JSON blob into the queue JSON table and returns // StoreJSON adds a JSON blob into the queue JSON table and returns

View File

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
) )
// Database stores information needed by the federation sender // Database stores information needed by the federation sender
@ -34,7 +35,7 @@ type Database struct {
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) {
var d Database var d Database
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if d.db, err = sqlutil.Open(dbProperties); err != nil {
@ -88,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
ServerName: serverName,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationJoinedHosts: joinedHosts, FederationJoinedHosts: joinedHosts,

View File

@ -24,15 +24,16 @@ import (
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, cache) return sqlite3.NewDatabase(dbProperties, cache, serverName)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, cache) return postgres.NewDatabase(dbProperties, cache, serverName)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View File

@ -20,13 +20,14 @@ import (
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
) )
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, cache) return sqlite3.NewDatabase(dbProperties, cache, serverName)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation") return nil, fmt.Errorf("can't use Postgres implementation")
default: default: