Key Backups (3/3) : Implement querying keys and various bugfixes (#1946)

* Add querying device keys

Makes a bunch of sytests pass

* Apparently only the current version supports uploading keys

* Linting
This commit is contained in:
kegsay 2021-07-27 19:29:32 +01:00 committed by GitHub
parent b3754d68fc
commit 32bf14a37c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 362 additions and 101 deletions

View File

@ -37,7 +37,7 @@ type keyBackupVersionCreateResponse struct {
type keyBackupVersionResponse struct { type keyBackupVersionResponse struct {
Algorithm string `json:"algorithm"` Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"` AuthData json.RawMessage `json:"auth_data"`
Count int `json:"count"` Count int64 `json:"count"`
ETag string `json:"etag"` ETag string `json:"etag"`
Version string `json:"version"` Version string `json:"version"`
} }
@ -89,7 +89,10 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI,
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp) userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
}, &queryResp)
if queryResp.Error != "" { if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
} }
@ -216,3 +219,73 @@ func UploadBackupKeys(
}, },
} }
} }
// Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set.
func GetBackupKeys(
req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string,
) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
ReturnKeys: true,
KeysForRoomID: roomID,
KeysForSessionID: sessionID,
}, &queryResp)
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}
if !queryResp.Exists {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("version not found"),
}
}
if sessionID != "" {
// return the key itself if it was found
roomData, ok := queryResp.Keys[roomID]
if ok {
key, ok := roomData[sessionID]
if ok {
return util.JSONResponse{
Code: 200,
JSON: key,
}
}
}
} else if roomID != "" {
roomData, ok := queryResp.Keys[roomID]
if ok {
// wrap response in "sessions"
return util.JSONResponse{
Code: 200,
JSON: struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{
Sessions: roomData,
},
}
}
} else {
// response is the same as the upload request
var resp keyBackupSessionRequest
resp.Rooms = make(map[string]struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
})
for roomID, roomData := range queryResp.Keys {
resp.Rooms[roomID] = struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{
Sessions: roomData,
}
}
return util.JSONResponse{
Code: 200,
JSON: resp,
}
}
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("keys not found"),
}
}

View File

@ -896,11 +896,15 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// Key Backup Versions // Key Backup Versions (Metadata)
r0mux.Handle("/room_keys/version/{versionID}",
r0mux.Handle("/room_keys/version/{version}",
httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version") vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
return KeyBackupVersion(req, userAPI, device, version) if err != nil {
return util.ErrorResponse(err)
}
return KeyBackupVersion(req, userAPI, device, vars["version"])
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version", r0mux.Handle("/room_keys/version",
@ -908,28 +912,22 @@ func Setup(
return KeyBackupVersion(req, userAPI, device, "") return KeyBackupVersion(req, userAPI, device, "")
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/version/{versionID}", r0mux.Handle("/room_keys/version/{version}",
httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version") vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if version == "" { if err != nil {
return util.JSONResponse{ return util.ErrorResponse(err)
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
} }
return ModifyKeyBackupVersionAuthData(req, userAPI, device, version) return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
}), }),
).Methods(http.MethodPut) ).Methods(http.MethodPut)
r0mux.Handle("/room_keys/version/{versionID}", r0mux.Handle("/room_keys/version/{version}",
httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
version := req.URL.Query().Get("version") vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if version == "" { if err != nil {
return util.JSONResponse{ return util.ErrorResponse(err)
Code: 400,
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
}
} }
return DeleteKeyBackupVersion(req, userAPI, device, version) return DeleteKeyBackupVersion(req, userAPI, device, vars["version"])
}), }),
).Methods(http.MethodDelete) ).Methods(http.MethodDelete)
r0mux.Handle("/room_keys/version", r0mux.Handle("/room_keys/version",
@ -938,7 +936,8 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// E2E Backup Keys // Inserting E2E Backup Keys
// Bulk room and session // Bulk room and session
r0mux.Handle("/room_keys/keys", r0mux.Handle("/room_keys/keys",
httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -973,6 +972,9 @@ func Setup(
} }
roomID := vars["roomID"] roomID := vars["roomID"]
var reqBody keyBackupSessionRequest var reqBody keyBackupSessionRequest
reqBody.Rooms = make(map[string]struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
})
reqBody.Rooms[roomID] = struct { reqBody.Rooms[roomID] = struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"` Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{ }{
@ -989,7 +991,7 @@ func Setup(
).Methods(http.MethodPut) ).Methods(http.MethodPut)
// Single room, single session // Single room, single session
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -1009,14 +1011,47 @@ func Setup(
roomID := vars["roomID"] roomID := vars["roomID"]
sessionID := vars["sessionID"] sessionID := vars["sessionID"]
var keyReq keyBackupSessionRequest var keyReq keyBackupSessionRequest
keyReq.Rooms = make(map[string]struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
})
keyReq.Rooms[roomID] = struct { keyReq.Rooms[roomID] = struct {
Sessions map[string]userapi.KeyBackupSession `json:"sessions"` Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
}{} }{
Sessions: make(map[string]userapi.KeyBackupSession),
}
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
return UploadBackupKeys(req, userAPI, device, version, &keyReq) return UploadBackupKeys(req, userAPI, device, version, &keyReq)
}), }),
).Methods(http.MethodPut) ).Methods(http.MethodPut)
// Querying E2E Backup Keys
r0mux.Handle("/room_keys/keys",
httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}",
httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "")
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
}),
).Methods(http.MethodGet, http.MethodOptions)
// Deleting E2E Backup Keys
// Supplying a device ID is deprecated. // Supplying a device ID is deprecated.
r0mux.Handle("/keys/upload/{deviceID}", r0mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {

View File

@ -540,3 +540,13 @@ Key notary server must not overwrite a valid key with a spurious result from the
GET /rooms/:room_id/aliases lists aliases GET /rooms/:room_id/aliases lists aliases
Only room members can list aliases of a room Only room members can list aliases of a room
Users with sufficient power-level can delete other's aliases Users with sufficient power-level can delete other's aliases
Can create backup version
Can update backup version
Responds correctly when backup is empty
Can backup keys
Can update keys with better versions
Will not update keys with worse versions
Will not back up to an old backup version
Can create more than 10 backup versions
Can delete backup
Deleted & recreated backups are empty

View File

@ -67,6 +67,23 @@ type KeyBackupSession struct {
SessionData json.RawMessage `json:"session_data"` SessionData json.RawMessage `json:"session_data"`
} }
func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if newKey.IsVerified && !a.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if newKey.FirstMessageIndex < a.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if newKey.ForwardedCount < a.ForwardedCount {
return true
}
return false
}
// Internal KeyBackupData for passing to/from the storage layer // Internal KeyBackupData for passing to/from the storage layer
type InternalKeyBackupSession struct { type InternalKeyBackupSession struct {
KeyBackupSession KeyBackupSession
@ -88,6 +105,10 @@ type PerformKeyBackupResponse struct {
type QueryKeyBackupRequest struct { type QueryKeyBackupRequest struct {
UserID string UserID string
Version string // the version to query, if blank it means the latest Version string // the version to query, if blank it means the latest
ReturnKeys bool // whether to return keys in the backup response or just the metadata
KeysForRoomID string // optional string to return keys which belong to this room
KeysForSessionID string // optional string to return keys which belong to this (room, session)
} }
type QueryKeyBackupResponse struct { type QueryKeyBackupResponse struct {
@ -96,9 +117,11 @@ type QueryKeyBackupResponse struct {
Algorithm string `json:"algorithm"` Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"` AuthData json.RawMessage `json:"auth_data"`
Count int `json:"count"` Count int64 `json:"count"`
ETag string `json:"etag"` ETag string `json:"etag"`
Version string `json:"version"` Version string `json:"version"`
Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true
} }
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData

View File

@ -475,6 +475,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err) res.Error = fmt.Sprintf("failed to update backup: %s", err)
} }
res.Exists = err == nil
res.Version = req.Version res.Version = req.Version
return return
} }
@ -483,8 +484,8 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
} }
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// ensure the version metadata exists // you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
if err != nil { if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err) res.Error = fmt.Sprintf("failed to query version: %s", err)
return return
@ -493,6 +494,11 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
res.Error = "backup was deleted" res.Error = "backup was deleted"
return return
} }
if version != req.Version {
res.BadInput = true
res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
return
}
res.Exists = true res.Exists = true
res.Version = version res.Version = version
@ -529,9 +535,21 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
} }
res.Algorithm = algorithm res.Algorithm = algorithm
res.AuthData = authData res.AuthData = authData
res.ETag = etag
res.Exists = !deleted res.Exists = !deleted
// TODO: if !req.ReturnKeys {
res.Count = 0 res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
res.ETag = etag if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return
}
res.Keys = result
} }

View File

@ -61,6 +61,8 @@ type Database interface {
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View File

@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL, is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL session_data TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
` `
const insertBackupKeySQL = "" + const insertBackupKeySQL = "" +
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2" "WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct { type keyBackupStatements struct {
insertBackupKeyStmt *sql.Stmt insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysByRoomIDStmt *sql.Stmt
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema) _, err = db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return return
} }
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
return
}
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
return
}
return return
} }
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() { for rows.Next() {
var key api.InternalKeyBackupSession var key api.InternalKeyBackupSession

View File

@ -67,7 +67,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema) _, err = db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {

View File

@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil { if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
/* if err = d.keyBackupVersions.prepare(db); err != nil {
if err = d.keyBackupVersions.prepare(db); err != nil { return nil, err
return nil, err }
} if err = d.keyBackups.prepare(db); err != nil {
if err = d.keyBackups.prepare(db); err != nil { return nil, err
return nil, err }
} */
return d, nil return d, nil
} }
@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup(
return return
} }
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret // nolint:nakedret
func (d *Database) UpsertBackupKeys( func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil { if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID] existingSession, ok := existingRoom[newKey.SessionID]
if ok { if ok {
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true changed = true
if err != nil { if err != nil {
@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
// create circular import loops
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if uploaded.IsVerified && !existing.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if uploaded.ForwardedCount < existing.ForwardedCount {
return true
}
return false
}

View File

@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL, is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL session_data TEXT NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
` `
const insertBackupKeySQL = "" + const insertBackupKeySQL = "" +
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2" "WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct { type keyBackupStatements struct {
insertBackupKeyStmt *sql.Stmt insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysByRoomIDStmt *sql.Stmt
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema) _, err = db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return return
} }
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
return
}
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
return
}
return return
} }
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() { for rows.Next() {
var key api.InternalKeyBackupSession var key api.InternalKeyBackupSession

View File

@ -65,7 +65,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema) _, err = db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {

View File

@ -100,13 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil { if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err return nil, err
} }
/* if err = d.keyBackupVersions.prepare(db); err != nil {
if err = d.keyBackupVersions.prepare(db); err != nil { return nil, err
return nil, err }
} if err = d.keyBackups.prepare(db); err != nil {
if err = d.keyBackups.prepare(db); err != nil { return nil, err
return nil, err }
} */
return d, nil return d, nil
} }
@ -459,6 +458,37 @@ func (d *Database) GetKeyBackup(
return return
} }
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret // nolint:nakedret
func (d *Database) UpsertBackupKeys( func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@ -486,7 +516,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil { if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID] existingSession, ok := existingRoom[newKey.SessionID]
if ok { if ok {
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true changed = true
if err != nil { if err != nil {
@ -531,22 +561,3 @@ func (d *Database) UpsertBackupKeys(
}) })
return return
} }
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
// create circular import loops
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if uploaded.IsVerified && !existing.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if uploaded.ForwardedCount < existing.ForwardedCount {
return true
}
return false
}