diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 0825bccb..86c8c10a 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -167,15 +167,11 @@ func (p *SyncAPIProducer) SendPresence( } func (p *SyncAPIProducer) SendDeviceListUpdate( - ctx context.Context, deviceListUpdate *gomatrixserverlib.DeviceListUpdateEvent, + ctx context.Context, deviceListUpdate gomatrixserverlib.RawJSON, origin string, ) (err error) { m := nats.NewMsg(p.TopicDeviceListUpdate) - m.Header.Set(jetstream.UserID, deviceListUpdate.UserID) - m.Data, err = json.Marshal(deviceListUpdate) - if err != nil { - return fmt.Errorf("json.Marshal: %w", err) - } - + m.Header.Set("origin", origin) + m.Data = deviceListUpdate log.Debugf("Sending device list update: %+v", m.Header) _, err = p.JetStream.PublishMsg(m, nats.Context(ctx)) return err diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 3d931996..a9714c65 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -359,7 +359,9 @@ func (t *txnReq) processEDUs(ctx context.Context) { } } case gomatrixserverlib.MDeviceListUpdate: - t.processDeviceListUpdate(ctx, e) + if err := t.producer.SendDeviceListUpdate(ctx, e.Content, e.Origin); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") + } case gomatrixserverlib.MReceipt: // https://matrix.org/docs/spec/server_server/r0.1.4#receipts payload := map[string]types.FederationReceiptMRead{} @@ -454,21 +456,3 @@ func (t *txnReq) processReceiptEvent(ctx context.Context, return nil } - -func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverlib.EDU) { - var payload gomatrixserverlib.DeviceListUpdateEvent - if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event") - return - } - if _, serverName, err := gomatrixserverlib.SplitID('@', payload.UserID); err != nil { - return - } else if serverName == t.ourServerName { - return - } else if serverName != t.Origin { - return - } - if err := t.producer.SendDeviceListUpdate(ctx, &payload); err != nil { - util.GetLogger(ctx).WithError(err).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") - } -} diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index d15f9426..575e4128 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -18,22 +18,24 @@ import ( "context" "encoding/json" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // DeviceListUpdateConsumer consumes device list updates that came in over federation. type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + serverName gomatrixserverlib.ServerName } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -44,11 +46,12 @@ func NewDeviceListUpdateConsumer( updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + serverName: cfg.Matrix.ServerName, } } @@ -69,6 +72,15 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M logrus.WithError(err).Errorf("Failed to read from device list update input topic") return true } + origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { + return true + } else if serverName == t.serverName { + return true + } else if serverName != origin { + return true + } + err := t.updater.Update(ctx, m) if err != nil { logrus.WithFields(logrus.Fields{