diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 79c33142..c9430543 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -186,7 +186,7 @@ func main() { ServerKeyAPI: serverKeyAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation), + KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI), ExtPublicRoomsProvider: provider, } monolith.AddAllPublicRoutes(base.Base.PublicAPIMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 3cf0168e..8666e8f5 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -141,7 +141,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, StateAPI: stateAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), + KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI), //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index 1bde5636..70d8394f 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -30,10 +30,11 @@ func main() { keyRing := serverKeyAPI.KeyRing() fsAPI := base.FederationSenderHTTPClient() rsAPI := base.RoomserverHTTPClient() + keyAPI := base.KeyServerHTTPClient() federationapi.AddPublicRoutes( base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing, - rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), + rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), keyAPI, ) base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI)) diff --git a/cmd/dendrite-key-server/main.go b/cmd/dendrite-key-server/main.go index 7dabc258..1aafa144 100644 --- a/cmd/dendrite-key-server/main.go +++ b/cmd/dendrite-key-server/main.go @@ -24,7 +24,7 @@ func main() { base := setup.NewBaseDendrite(cfg, "KeyServer", true) defer base.Close() // nolint: errcheck - intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient()) + intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 93d62343..80a45c99 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -119,7 +119,7 @@ func main() { rsImpl.SetFederationSenderAPI(fsAPI) stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) - keyAPI := keyserver.NewInternalAPI(base.Cfg, federation) + keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI) monolith := setup.Monolith{ Config: base.Cfg, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 3d58d957..0bb2dbe9 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -233,7 +233,7 @@ func main() { RoomserverAPI: rsAPI, StateAPI: stateAPI, UserAPI: userAPI, - KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation), + KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI), //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: p2pPublicRoomProvider, } diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 7d1994b2..079f333a 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -20,6 +20,7 @@ import ( eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" + keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -38,11 +39,12 @@ func AddPublicRoutes( federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + keyAPI keyserverAPI.KeyInternalAPI, ) { routing.Setup( router, cfg, rsAPI, eduAPI, federationSenderAPI, keyRing, - federation, userAPI, stateAPI, + federation, userAPI, stateAPI, keyAPI, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 6bbe9d80..8bc4277e 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { fsAPI := base.FederationSenderHTTPClient() // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil) + federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil, nil) httputil.SetupHTTPAPI( base.BaseMux, base.PublicAPIMux, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index a1dd0fd0..90eec9e0 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -19,12 +19,106 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "golang.org/x/crypto/ed25519" ) +type queryKeysRequest struct { + DeviceKeys map[string][]string `json:"device_keys"` +} + +// QueryDeviceKeys returns device keys for users on this server. +// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query +func QueryDeviceKeys( + httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName, +) util.JSONResponse { + var qkr queryKeysRequest + err := json.Unmarshal(request.Content(), &qkr) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + } + } + // make sure we only query users on our domain + for userID := range qkr.DeviceKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + delete(qkr.DeviceKeys, userID) + continue // ignore invalid users + } + if serverName != thisServer { + delete(qkr.DeviceKeys, userID) + continue + } + } + + var queryRes api.QueryKeysResponse + keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ + UserToDevices: qkr.DeviceKeys, + }, &queryRes) + if queryRes.Error != nil { + util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: 200, + JSON: struct { + DeviceKeys interface{} `json:"device_keys"` + }{queryRes.DeviceKeys}, + } +} + +type claimOTKsRequest struct { + OneTimeKeys map[string]map[string]string `json:"one_time_keys"` +} + +// ClaimOneTimeKeys claims OTKs for users on this server. +// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim +func ClaimOneTimeKeys( + httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName, +) util.JSONResponse { + var cor claimOTKsRequest + err := json.Unmarshal(request.Content(), &cor) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + } + } + // make sure we only claim users on our domain + for userID := range cor.OneTimeKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + delete(cor.OneTimeKeys, userID) + continue // ignore invalid users + } + if serverName != thisServer { + delete(cor.OneTimeKeys, userID) + continue + } + } + + var claimRes api.PerformClaimKeysResponse + keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ + OneTimeKeys: cor.OneTimeKeys, + }, &claimRes) + if claimRes.Error != nil { + util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: 200, + JSON: struct { + OneTimeKeys interface{} `json:"one_time_keys"` + }{claimRes.OneTimeKeys}, + } +} + // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys func LocalKeys(cfg *config.Dendrite) util.JSONResponse { diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index cd97f297..50b7bdd2 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -24,6 +24,7 @@ import ( federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/httputil" + keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -54,6 +55,7 @@ func Setup( federation *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + keyAPI keyserverAPI.KeyInternalAPI, ) { v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter() v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter() @@ -299,4 +301,18 @@ func Setup( return GetPostPublicRooms(req, rsAPI, stateAPI) }), ).Methods(http.MethodGet) + + v1fedmux.Handle("/user/keys/claim", httputil.MakeFedAPI( + "federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + }, + )).Methods(http.MethodPost) + + v1fedmux.Handle("/user/keys/query", httputil.MakeFedAPI( + "federation_keys_query", cfg.Matrix.ServerName, keys, wakeup, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + }, + )).Methods(http.MethodPost) } diff --git a/go.mod b/go.mod index dfdc6644..f087b087 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b + github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible diff --git a/go.sum b/go.sum index a7c8a05b..de7527d9 100644 --- a/go.sum +++ b/go.sum @@ -423,6 +423,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ= github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 39013a2c..1f6d9a76 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -73,7 +73,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) { federationapi.AddPublicRoutes( publicMux, m.Config, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, - m.EDUInternalAPI, m.StateAPI, + m.EDUInternalAPI, m.StateAPI, m.KeyAPI, ) mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client) syncapi.AddPublicRoutes( diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index e406dab4..174a72dc 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -24,7 +24,9 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -33,6 +35,7 @@ type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName FedClient *gomatrixserverlib.FederationClient + UserAPI userapi.UserInternalAPI } func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { @@ -66,11 +69,25 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), } } - mergeInto(res.OneTimeKeys, keys) + util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys") + for _, key := range keys { + _, ok := res.OneTimeKeys[key.UserID] + if !ok { + res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage) + } + _, ok = res.OneTimeKeys[key.UserID][key.DeviceID] + if !ok { + res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) + } + for keyID, keyJSON := range key.KeyJSON { + res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON + } + } delete(domainToDeviceKeys, string(a.ThisServer)) } - // claim remote keys - a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + if len(domainToDeviceKeys) > 0 { + a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + } } func (a *KeyInternalAPI) claimRemoteKeys( @@ -82,6 +99,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( wg.Add(len(domainToDeviceKeys)) // mutex for failures var failMu sync.Mutex + util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") // fan out for d, k := range domainToDeviceKeys { @@ -91,6 +109,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) if err != nil { + util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") failMu.Lock() res.Failures[domain] = map[string]interface{}{ "message": err.Error(), @@ -108,6 +127,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( close(resultCh) }() + keysClaimed := 0 for result := range resultCh { for userID, nest := range result.OneTimeKeys { res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) @@ -119,10 +139,12 @@ func (a *KeyInternalAPI) claimRemoteKeys( continue } res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON + keysClaimed++ } } } } + util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { @@ -145,13 +167,28 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } return } + + // pull out display names after we have the keys so we handle wildcards correctly + var dids []string + for _, dk := range deviceKeys { + dids = append(dids, dk.DeviceID) + } + var queryRes userapi.QueryDeviceInfosResponse + err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + DeviceIDs: dids, + }, &queryRes) + if err != nil { + util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + } + if res.DeviceKeys[userID] == nil { res.DeviceKeys[userID] = make(map[string]json.RawMessage) } for _, dk := range deviceKeys { - // inject an empty 'unsigned' key which should be used for display names - // (but not via this API? unsure when they should be added) - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{}) + // inject display name if known + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { @@ -298,19 +335,3 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { // TODO } - -func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) { - for _, key := range src { - _, ok := dst[key.UserID] - if !ok { - dst[key.UserID] = make(map[string]map[string]json.RawMessage) - } - _, ok = dst[key.UserID][key.DeviceID] - if !ok { - dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) - } - for keyID, keyJSON := range key.KeyJSON { - dst[key.UserID][key.DeviceID][keyID] = keyJSON - } - } -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 714b59f0..2e1ddb6c 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/storage" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -33,7 +34,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI { +func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI) api.KeyInternalAPI { db, err := storage.NewDatabase( string(cfg.Database.E2EKey), cfg.DbProperties(), @@ -45,5 +46,6 @@ func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.Federatio DB: db, ThisServer: cfg.Matrix.ServerName, FedClient: fedClient, + UserAPI: userAPI, } } diff --git a/sytest-whitelist b/sytest-whitelist index f21432fb..5bf6d68b 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -122,9 +122,11 @@ User can invite local user to room with version 1 Can upload device keys Should reject keys claiming to belong to a different user Can query device keys using POST +Can query remote device keys using POST Can query specific device keys using POST query for user with no keys returns empty key dict Can claim one time key using POST +Can claim remote one time key using POST Can add account data Can add account data to room Can get account data without syncing diff --git a/userapi/api/api.go b/userapi/api/api.go index cf0f0563..bd0773f8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -30,6 +30,7 @@ type UserInternalAPI interface { QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -44,6 +45,19 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +// QueryDeviceInfosRequest is the request to QueryDeviceInfos +type QueryDeviceInfosRequest struct { + DeviceIDs []string +} + +// QueryDeviceInfosResponse is the response to QueryDeviceInfos +type QueryDeviceInfosResponse struct { + DeviceInfo map[string]struct { + DisplayName string + UserID string + } +} + // QueryAccessTokenRequest is the request for QueryAccessToken type QueryAccessTokenRequest struct { AccessToken string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 1d10d1d8..2de8f960 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil return nil } +func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { + devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + if err != nil { + return err + } + res.DeviceInfo = make(map[string]struct { + DisplayName string + UserID string + }) + for _, d := range devices { + res.DeviceInfo[d.ID] = struct { + DisplayName string + UserID string + }{ + DisplayName: d.DisplayName, + UserID: d.UserID, + } + } + return nil +} + func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4ab0d690..b2b42823 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -35,6 +35,7 @@ const ( QueryAccessTokenPath = "/userapi/queryAccessToken" QueryDevicesPath = "/userapi/queryDevices" QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) QueryDeviceInfos( + ctx context.Context, + request *api.QueryDeviceInfosRequest, + response *api.QueryDeviceInfosResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos") + defer span.Finish() + + apiURL := h.apiURL + QueryDeviceInfosPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryAccessToken( ctx context.Context, request *api.QueryAccessTokenRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 8f3be773..d8e151ad 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/util" ) +// nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { @@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryDeviceInfosPath, + httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse { + request := api.QueryDeviceInfosRequest{} + response := api.QueryDeviceInfosResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 4bdb5785..3c9ec934 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -24,6 +24,7 @@ type Database interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 1d036d1b..03bf7c72 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt @@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} + func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 801657bd..6ac802bb 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart( return d.devices.selectDevicesByLocalpart(ctx, localpart) } +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ec52c64b..efe6f927 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + type devicesStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -79,6 +83,7 @@ type devicesStatements struct { selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt @@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } @@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } + +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) + iDeviceIDs := make([]interface{}, len(deviceIDs)) + for i := range deviceIDs { + iDeviceIDs[i] = deviceIDs[i] + } + + rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f248abda..b9f08ca1 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart( return d.devices.selectDevicesByLocalpart(ctx, localpart) } +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device,