diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 9ba3988b..c9ec59a7 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -21,9 +21,10 @@ import ( "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/keyserver/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type KeyInternalAPI interface { @@ -56,6 +57,7 @@ type UserKeyAPI interface { type SyncKeyAPI interface { QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error } type FederationKeyAPI interface { @@ -335,3 +337,9 @@ type QuerySignaturesResponse struct { // The request error, if any Error *KeyError } + +type PerformMarkAsStaleRequest struct { + UserID string + Domain gomatrixserverlib.ServerName + DeviceID string +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 41b4d44a..a8d1128c 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -23,16 +23,17 @@ import ( "sync" "time" - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/producers" + "github.com/matrix-org/dendrite/keyserver/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI struct { @@ -224,6 +225,19 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query return nil } +// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present +// in our database. +func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { + knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{req.DeviceID}, true) + if err != nil { + return err + } + if len(knownDevices) == 0 { + return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) + } + return nil +} + // nolint:gocyclo func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { res.DeviceKeys = make(map[string]map[string]json.RawMessage) diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index 7a713114..75d537d9 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -37,6 +37,7 @@ const ( QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages" QuerySignaturesPath = "/keyserver/querySignatures" + PerformMarkAsStalePath = "/keyserver/markAsStale" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. @@ -172,3 +173,14 @@ func (h *httpKeyInternalAPI) QuerySignatures( h.httpClient, ctx, request, response, ) } + +func (h *httpKeyInternalAPI) PerformMarkAsStaleIfNeeded( + ctx context.Context, + request *api.PerformMarkAsStaleRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "MarkAsStale", h.apiURL+PerformMarkAsStalePath, + h.httpClient, ctx, request, response, + ) +} diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 4e5f9fba..7af0ff6e 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -16,6 +16,7 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" ) @@ -70,4 +71,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { QuerySignaturesPath, httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures), ) + + internalAPIMux.Handle( + PerformMarkAsStalePath, + httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", s.PerformMarkAsStaleIfNeeded), + ) } diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 7d6aae59..c0b43225 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -23,7 +23,9 @@ import ( "github.com/matrix-org/util" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -39,6 +41,7 @@ type OutputSendToDeviceEventConsumer struct { durable string topic string db storage.Database + keyAPI keyapi.SyncKeyAPI serverName gomatrixserverlib.ServerName // our server name stream types.StreamProvider notifier *notifier.Notifier @@ -51,6 +54,7 @@ func NewOutputSendToDeviceEventConsumer( cfg *config.SyncAPI, js nats.JetStreamContext, store storage.Database, + keyAPI keyapi.SyncKeyAPI, notifier *notifier.Notifier, stream types.StreamProvider, ) *OutputSendToDeviceEventConsumer { @@ -60,6 +64,7 @@ func NewOutputSendToDeviceEventConsumer( topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), db: store, + keyAPI: keyAPI, serverName: cfg.Matrix.ServerName, notifier: notifier, stream: stream, @@ -96,12 +101,28 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] return true } - util.GetLogger(context.TODO()).WithFields(log.Fields{ + logger := util.GetLogger(context.TODO()).WithFields(log.Fields{ "sender": output.Sender, "user_id": output.UserID, "device_id": output.DeviceID, "event_type": output.Type, - }).Debugf("sync API received send-to-device event from the clientapi/federationsender") + }) + logger.Debugf("sync API received send-to-device event from the clientapi/federationsender") + + // Check we actually got the requesting device in our store, if we receive a room key request + if output.Type == "m.room_key_request" { + requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str + _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) + if requestingDeviceID != "" && senderDomain != s.serverName { + // Mark the requesting device as stale, if we don't know about it. + if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ + UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, + }, &struct{}{}); err != nil { + logger.WithError(err).Errorf("failed to mark as stale if needed") + return false + } + } + } streamPos, err := s.db.StoreNewSendForDeviceMessage( s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 80d2811b..3b9c8221 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -22,6 +22,10 @@ var ( type mockKeyAPI struct{} +func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error { + return nil +} + func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { return nil } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 92db18d5..68537bc4 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -17,9 +17,10 @@ package syncapi import ( "context" - "github.com/matrix-org/dendrite/internal/caching" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal/caching" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -126,7 +127,7 @@ func AddPublicRoutes( } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - base.ProcessContext, cfg, js, syncDB, notifier, streams.SendToDeviceStreamProvider, + base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer")