diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 56b2faf7..5f7bfb18 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -17,6 +17,7 @@ package routing import ( "encoding/json" "net/http" + "time" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -25,18 +26,6 @@ import ( "github.com/matrix-org/util" ) -func QueryKeys( - req *http.Request, -) util.JSONResponse { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: map[string]interface{}{ - "failures": map[string]interface{}{}, - "device_keys": map[string]interface{}{}, - }, - } -} - type uploadKeysRequest struct { DeviceKeys json.RawMessage `json:"device_keys"` OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` @@ -94,3 +83,37 @@ func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.De }{keyCount}, } } + +type queryKeysRequest struct { + Timeout int `json:"timeout"` + Token string `json:"token"` + DeviceKeys map[string][]string `json:"device_keys"` +} + +func (r *queryKeysRequest) GetTimeout() time.Duration { + if r.Timeout == 0 { + return 10 * time.Second + } + return time.Duration(r.Timeout) * time.Millisecond +} + +func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse { + var r queryKeysRequest + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + queryRes := api.QueryKeysResponse{} + keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ + UserToDevices: r.DeviceKeys, + Timeout: r.GetTimeout(), + // TODO: Token? + }, &queryRes) + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "device_keys": queryRes.DeviceKeys, + "failures": queryRes.Failures, + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4879ddaa..492b7e25 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -698,12 +698,6 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/keys/query", - httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return QueryKeys(req) - }), - ).Methods(http.MethodPost, http.MethodOptions) - // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -715,4 +709,9 @@ func Setup( return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/keys/query", + httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return QueryKeys(req, keyAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index d1eac703..0f6cb797 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "strings" + "time" ) type KeyInternalAPI interface { @@ -108,8 +109,16 @@ type PerformClaimKeysResponse struct { } type QueryKeysRequest struct { + // Maps user IDs to a list of devices + UserToDevices map[string][]string + Timeout time.Duration } type QueryKeysResponse struct { + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Map of user_id to device_id to device_key + DeviceKeys map[string]map[string]json.RawMessage + // Set if there was a fatal error processing this query Error *KeyError } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index a7b0f93c..5be87aa4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -17,15 +17,19 @@ package internal import ( "bytes" "context" + "encoding/json" "fmt" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) type KeyInternalAPI struct { - DB storage.Database + DB storage.Database + ThisServer gomatrixserverlib.ServerName } func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { @@ -37,7 +41,45 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC } func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { - + res.DeviceKeys = make(map[string]map[string]json.RawMessage) + res.Failures = make(map[string]interface{}) + // make a map from domain to device keys + domainToUserToDevice := make(map[string][]api.DeviceKeys) + for userID, deviceIDs := range req.UserToDevices { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + domain := string(serverName) + // query local devices + if serverName == a.ThisServer { + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query local device keys: %s", err), + } + return + } + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, dk := range deviceKeys { + // inject an empty 'unsigned' key which should be used for display names + // (but not via this API? unsure when they should be added) + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{}) + res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON + } + } else { + for _, deviceID := range deviceIDs { + domainToUserToDevice[domain] = append(domainToUserToDevice[domain], api.DeviceKeys{ + UserID: userID, + DeviceID: deviceID, + }) + } + } + } + // TODO: set device display names when they are known + // TODO: perform key queries for remote devices } func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 405eac52..3c70fc21 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -41,6 +41,7 @@ func NewInternalAPI(cfg *config.Dendrite) api.KeyInternalAPI { logrus.WithError(err).Panicf("failed to connect to key server database") } return &internal.KeyInternalAPI{ - DB: db, + DB: db, + ThisServer: cfg.Matrix.ServerName, } } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 3697b197..a626c66a 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -35,4 +35,8 @@ type Database interface { // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. // Returns an error if there was a problem storing the keys. StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. + // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) } diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index b05ec093..d915246c 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -45,10 +46,14 @@ const upsertDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" + "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -65,6 +70,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -95,3 +103,30 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api. return nil }) } + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { + rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var result []api.DeviceKeys + for rows.Next() { + var dk api.DeviceKeys + dk.UserID = userID + var keyJSON string + if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + return nil, err + } + dk.KeyJSON = []byte(keyJSON) + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 28e1f459..d5ac6458 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -44,3 +44,7 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) er func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) } + +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) +} diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 93b8ecd8..69fe7a6e 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -45,10 +46,14 @@ const upsertDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" + "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -65,9 +70,39 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + var result []api.DeviceKeys + for rows.Next() { + var dk api.DeviceKeys + dk.UserID = userID + var keyJSON string + if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + return nil, err + } + dk.KeyJSON = []byte(keyJSON) + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} + func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { for i, key := range keys { var keyJSONStr string diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 20667ffb..1f7f686b 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -29,4 +29,5 @@ type OneTimeKeys interface { type DeviceKeys interface { SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) } diff --git a/sytest-whitelist b/sytest-whitelist index 3d40f042..a3df4e0c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -121,6 +121,9 @@ local user can join room with version 1 User can invite local user to room with version 1 Can upload device keys Should reject keys claiming to belong to a different user +Can query device keys using POST +Can query specific device keys using POST +query for user with no keys returns empty key dict Can add account data Can add account data to room Can get account data without syncing