mirror of
https://github.com/1f349/dendrite.git
synced 2025-01-10 17:36:28 +00:00
Exclude our own server name in GetJoinedHostsForRooms
(#2110)
* Exclude our own servername * Make excluding self behaviour optional
This commit is contained in:
parent
49a618dfe2
commit
8a1bc70524
@ -187,6 +187,7 @@ type PerformServersAliveResponse struct {
|
|||||||
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
|
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
|
||||||
type QueryJoinedHostServerNamesInRoomRequest struct {
|
type QueryJoinedHostServerNamesInRoomRequest struct {
|
||||||
RoomID string `json:"room_id"`
|
RoomID string `json:"room_id"`
|
||||||
|
ExcludeSelf bool `json:"exclude_self"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
|
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
|
||||||
|
@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// send this key change to all servers who share rooms with this user.
|
// send this key change to all servers who share rooms with this user.
|
||||||
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs)
|
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
|
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
|
||||||
return nil
|
return nil
|
||||||
@ -180,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// send this key change to all servers who share rooms with this user.
|
// send this key change to all servers who share rooms with this user.
|
||||||
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs)
|
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
|
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
|
||||||
return nil
|
return nil
|
||||||
|
@ -78,7 +78,7 @@ func NewInternalAPI(
|
|||||||
) api.FederationInternalAPI {
|
) api.FederationInternalAPI {
|
||||||
cfg := &base.Cfg.FederationAPI
|
cfg := &base.Cfg.FederationAPI
|
||||||
|
|
||||||
federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches)
|
federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches, base.Cfg.Global.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panic("failed to connect to federation sender db")
|
logrus.WithError(err).Panic("failed to connect to federation sender db")
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
|
|||||||
request *api.QueryJoinedHostServerNamesInRoomRequest,
|
request *api.QueryJoinedHostServerNamesInRoomRequest,
|
||||||
response *api.QueryJoinedHostServerNamesInRoomResponse,
|
response *api.QueryJoinedHostServerNamesInRoomResponse,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID})
|
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ type Database interface {
|
|||||||
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
|
||||||
PurgeRoomState(ctx context.Context, roomID string) error
|
PurgeRoomState(ctx context.Context, roomID string) error
|
||||||
|
|
||||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database stores information needed by the federation sender
|
// Database stores information needed by the federation sender
|
||||||
@ -35,7 +36,7 @@ type Database struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
var d Database
|
var d Database
|
||||||
var err error
|
var err error
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
@ -89,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
|
|||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
ServerName: serverName,
|
||||||
Cache: cache,
|
Cache: cache,
|
||||||
Writer: d.writer,
|
Writer: d.writer,
|
||||||
FederationJoinedHosts: joinedHosts,
|
FederationJoinedHosts: joinedHosts,
|
||||||
|
@ -29,6 +29,7 @@ import (
|
|||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
Cache caching.FederationCache
|
Cache caching.FederationCache
|
||||||
Writer sqlutil.Writer
|
Writer sqlutil.Writer
|
||||||
FederationQueuePDUs tables.FederationQueuePDUs
|
FederationQueuePDUs tables.FederationQueuePDUs
|
||||||
@ -102,8 +103,19 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
|
|||||||
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
|
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) {
|
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) {
|
||||||
return d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
|
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if excludeSelf {
|
||||||
|
for i, server := range servers {
|
||||||
|
if server == d.ServerName {
|
||||||
|
servers = append(servers[:i], servers[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return servers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreJSON adds a JSON blob into the queue JSON table and returns
|
// StoreJSON adds a JSON blob into the queue JSON table and returns
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database stores information needed by the federation sender
|
// Database stores information needed by the federation sender
|
||||||
@ -34,7 +35,7 @@ type Database struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
var d Database
|
var d Database
|
||||||
var err error
|
var err error
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
@ -88,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC
|
|||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
ServerName: serverName,
|
||||||
Cache: cache,
|
Cache: cache,
|
||||||
Writer: d.writer,
|
Writer: d.writer,
|
||||||
FederationJoinedHosts: joinedHosts,
|
FederationJoinedHosts: joinedHosts,
|
||||||
|
@ -24,15 +24,16 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, cache)
|
return sqlite3.NewDatabase(dbProperties, cache, serverName)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return postgres.NewDatabase(dbProperties, cache)
|
return postgres.NewDatabase(dbProperties, cache, serverName)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
return nil, fmt.Errorf("unexpected database type")
|
||||||
}
|
}
|
||||||
|
@ -20,13 +20,14 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, cache)
|
return sqlite3.NewDatabase(dbProperties, cache, serverName)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||||
default:
|
default:
|
||||||
|
Loading…
Reference in New Issue
Block a user