mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-25 05:01:41 +00:00
Key Backups (3/3) : Implement querying keys and various bugfixes (#1946)
* Add querying device keys Makes a bunch of sytests pass * Apparently only the current version supports uploading keys * Linting
This commit is contained in:
parent
b3754d68fc
commit
32bf14a37c
@ -37,7 +37,7 @@ type keyBackupVersionCreateResponse struct {
|
||||
type keyBackupVersionResponse struct {
|
||||
Algorithm string `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
Count int `json:"count"`
|
||||
Count int64 `json:"count"`
|
||||
ETag string `json:"etag"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
@ -89,7 +89,10 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI,
|
||||
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
|
||||
func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||
var queryResp userapi.QueryKeyBackupResponse
|
||||
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp)
|
||||
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
}, &queryResp)
|
||||
if queryResp.Error != "" {
|
||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
||||
}
|
||||
@ -216,3 +219,73 @@ func UploadBackupKeys(
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set.
|
||||
func GetBackupKeys(
|
||||
req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string,
|
||||
) util.JSONResponse {
|
||||
var queryResp userapi.QueryKeyBackupResponse
|
||||
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||
UserID: device.UserID,
|
||||
Version: version,
|
||||
ReturnKeys: true,
|
||||
KeysForRoomID: roomID,
|
||||
KeysForSessionID: sessionID,
|
||||
}, &queryResp)
|
||||
if queryResp.Error != "" {
|
||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
||||
}
|
||||
if !queryResp.Exists {
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("version not found"),
|
||||
}
|
||||
}
|
||||
if sessionID != "" {
|
||||
// return the key itself if it was found
|
||||
roomData, ok := queryResp.Keys[roomID]
|
||||
if ok {
|
||||
key, ok := roomData[sessionID]
|
||||
if ok {
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: key,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if roomID != "" {
|
||||
roomData, ok := queryResp.Keys[roomID]
|
||||
if ok {
|
||||
// wrap response in "sessions"
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
}{
|
||||
Sessions: roomData,
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// response is the same as the upload request
|
||||
var resp keyBackupSessionRequest
|
||||
resp.Rooms = make(map[string]struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
})
|
||||
for roomID, roomData := range queryResp.Keys {
|
||||
resp.Rooms[roomID] = struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
}{
|
||||
Sessions: roomData,
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: resp,
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("keys not found"),
|
||||
}
|
||||
}
|
||||
|
@ -896,11 +896,15 @@ func Setup(
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// Key Backup Versions
|
||||
r0mux.Handle("/room_keys/version/{versionID}",
|
||||
// Key Backup Versions (Metadata)
|
||||
|
||||
r0mux.Handle("/room_keys/version/{version}",
|
||||
httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
version := req.URL.Query().Get("version")
|
||||
return KeyBackupVersion(req, userAPI, device, version)
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return KeyBackupVersion(req, userAPI, device, vars["version"])
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/version",
|
||||
@ -908,28 +912,22 @@ func Setup(
|
||||
return KeyBackupVersion(req, userAPI, device, "")
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/version/{versionID}",
|
||||
r0mux.Handle("/room_keys/version/{version}",
|
||||
httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
version := req.URL.Query().Get("version")
|
||||
if version == "" {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
}
|
||||
return ModifyKeyBackupVersionAuthData(req, userAPI, device, version)
|
||||
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
r0mux.Handle("/room_keys/version/{versionID}",
|
||||
r0mux.Handle("/room_keys/version/{version}",
|
||||
httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
version := req.URL.Query().Get("version")
|
||||
if version == "" {
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
}
|
||||
return DeleteKeyBackupVersion(req, userAPI, device, version)
|
||||
return DeleteKeyBackupVersion(req, userAPI, device, vars["version"])
|
||||
}),
|
||||
).Methods(http.MethodDelete)
|
||||
r0mux.Handle("/room_keys/version",
|
||||
@ -938,7 +936,8 @@ func Setup(
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// E2E Backup Keys
|
||||
// Inserting E2E Backup Keys
|
||||
|
||||
// Bulk room and session
|
||||
r0mux.Handle("/room_keys/keys",
|
||||
httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
@ -973,6 +972,9 @@ func Setup(
|
||||
}
|
||||
roomID := vars["roomID"]
|
||||
var reqBody keyBackupSessionRequest
|
||||
reqBody.Rooms = make(map[string]struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
})
|
||||
reqBody.Rooms[roomID] = struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
}{
|
||||
@ -989,7 +991,7 @@ func Setup(
|
||||
).Methods(http.MethodPut)
|
||||
// Single room, single session
|
||||
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
|
||||
httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
@ -1009,14 +1011,47 @@ func Setup(
|
||||
roomID := vars["roomID"]
|
||||
sessionID := vars["sessionID"]
|
||||
var keyReq keyBackupSessionRequest
|
||||
keyReq.Rooms = make(map[string]struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
})
|
||||
keyReq.Rooms[roomID] = struct {
|
||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||
}{}
|
||||
}{
|
||||
Sessions: make(map[string]userapi.KeyBackupSession),
|
||||
}
|
||||
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
|
||||
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
||||
}),
|
||||
).Methods(http.MethodPut)
|
||||
|
||||
// Querying E2E Backup Keys
|
||||
|
||||
r0mux.Handle("/room_keys/keys",
|
||||
httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/keys/{roomID}",
|
||||
httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "")
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
|
||||
httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// Deleting E2E Backup Keys
|
||||
|
||||
// Supplying a device ID is deprecated.
|
||||
r0mux.Handle("/keys/upload/{deviceID}",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -540,3 +540,13 @@ Key notary server must not overwrite a valid key with a spurious result from the
|
||||
GET /rooms/:room_id/aliases lists aliases
|
||||
Only room members can list aliases of a room
|
||||
Users with sufficient power-level can delete other's aliases
|
||||
Can create backup version
|
||||
Can update backup version
|
||||
Responds correctly when backup is empty
|
||||
Can backup keys
|
||||
Can update keys with better versions
|
||||
Will not update keys with worse versions
|
||||
Will not back up to an old backup version
|
||||
Can create more than 10 backup versions
|
||||
Can delete backup
|
||||
Deleted & recreated backups are empty
|
||||
|
@ -67,6 +67,23 @@ type KeyBackupSession struct {
|
||||
SessionData json.RawMessage `json:"session_data"`
|
||||
}
|
||||
|
||||
func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool {
|
||||
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
||||
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
|
||||
if newKey.IsVerified && !a.IsVerified {
|
||||
return true
|
||||
}
|
||||
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
|
||||
if newKey.FirstMessageIndex < a.FirstMessageIndex {
|
||||
return true
|
||||
}
|
||||
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
|
||||
if newKey.ForwardedCount < a.ForwardedCount {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Internal KeyBackupData for passing to/from the storage layer
|
||||
type InternalKeyBackupSession struct {
|
||||
KeyBackupSession
|
||||
@ -88,6 +105,10 @@ type PerformKeyBackupResponse struct {
|
||||
type QueryKeyBackupRequest struct {
|
||||
UserID string
|
||||
Version string // the version to query, if blank it means the latest
|
||||
|
||||
ReturnKeys bool // whether to return keys in the backup response or just the metadata
|
||||
KeysForRoomID string // optional string to return keys which belong to this room
|
||||
KeysForSessionID string // optional string to return keys which belong to this (room, session)
|
||||
}
|
||||
|
||||
type QueryKeyBackupResponse struct {
|
||||
@ -96,9 +117,11 @@ type QueryKeyBackupResponse struct {
|
||||
|
||||
Algorithm string `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
Count int `json:"count"`
|
||||
Count int64 `json:"count"`
|
||||
ETag string `json:"etag"`
|
||||
Version string `json:"version"`
|
||||
|
||||
Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true
|
||||
}
|
||||
|
||||
// InputAccountDataRequest is the request for InputAccountData
|
||||
|
@ -475,6 +475,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
||||
}
|
||||
res.Exists = err == nil
|
||||
res.Version = req.Version
|
||||
return
|
||||
}
|
||||
@ -483,8 +484,8 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
||||
// ensure the version metadata exists
|
||||
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
|
||||
// you can only upload keys for the CURRENT version
|
||||
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
||||
return
|
||||
@ -493,6 +494,11 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
||||
res.Error = "backup was deleted"
|
||||
return
|
||||
}
|
||||
if version != req.Version {
|
||||
res.BadInput = true
|
||||
res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
|
||||
return
|
||||
}
|
||||
res.Exists = true
|
||||
res.Version = version
|
||||
|
||||
@ -529,9 +535,21 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
||||
}
|
||||
res.Algorithm = algorithm
|
||||
res.AuthData = authData
|
||||
res.ETag = etag
|
||||
res.Exists = !deleted
|
||||
|
||||
// TODO:
|
||||
res.Count = 0
|
||||
res.ETag = etag
|
||||
if !req.ReturnKeys {
|
||||
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||
if err != nil {
|
||||
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
||||
return
|
||||
}
|
||||
res.Keys = result
|
||||
}
|
||||
|
@ -61,6 +61,8 @@ type Database interface {
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
}
|
||||
|
||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||
|
@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||
is_verified BOOLEAN NOT NULL,
|
||||
session_data TEXT NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||
`
|
||||
|
||||
const insertBackupKeySQL = "" +
|
||||
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysByRoomIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||
|
||||
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||
|
||||
type keyBackupStatements struct {
|
||||
insertBackupKeyStmt *sql.Stmt
|
||||
updateBackupKeyStmt *sql.Stmt
|
||||
countKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysByRoomIDStmt *sql.Stmt
|
||||
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// nolint:unused
|
||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupTableSchema)
|
||||
if err != nil {
|
||||
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
|
||||
func (s *keyBackupStatements) selectKeys(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
result := make(map[string]map[string]api.KeyBackupSession)
|
||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
result := make(map[string]map[string]api.KeyBackupSession)
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
|
||||
for rows.Next() {
|
||||
var key api.InternalKeyBackupSession
|
||||
|
@ -67,7 +67,6 @@ type keyBackupVersionStatements struct {
|
||||
updateKeyBackupETagStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// nolint:unused
|
||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
||||
if err != nil {
|
||||
|
@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
/*
|
||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.keyBackups.prepare(db); err != nil {
|
||||
return nil, err
|
||||
} */
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup(
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetBackupKeys(
|
||||
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
if filterSessionID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||
return err
|
||||
}
|
||||
if filterRoomID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||
return err
|
||||
}
|
||||
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) CountBackupKeys(
|
||||
ctx context.Context, version, userID string,
|
||||
) (count int64, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:nakedret
|
||||
func (d *Database) UpsertBackupKeys(
|
||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||
@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys(
|
||||
if existingRoom != nil {
|
||||
existingSession, ok := existingRoom[newKey.SessionID]
|
||||
if ok {
|
||||
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
|
||||
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
||||
changed = true
|
||||
if err != nil {
|
||||
@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys(
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
|
||||
// create circular import loops
|
||||
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
|
||||
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
||||
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
|
||||
if uploaded.IsVerified && !existing.IsVerified {
|
||||
return true
|
||||
}
|
||||
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
|
||||
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
|
||||
return true
|
||||
}
|
||||
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
|
||||
if uploaded.ForwardedCount < existing.ForwardedCount {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||
is_verified BOOLEAN NOT NULL,
|
||||
session_data TEXT NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||
`
|
||||
|
||||
const insertBackupKeySQL = "" +
|
||||
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysByRoomIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||
|
||||
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||
|
||||
type keyBackupStatements struct {
|
||||
insertBackupKeyStmt *sql.Stmt
|
||||
updateBackupKeyStmt *sql.Stmt
|
||||
countKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysByRoomIDStmt *sql.Stmt
|
||||
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// nolint:unused
|
||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupTableSchema)
|
||||
if err != nil {
|
||||
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
|
||||
func (s *keyBackupStatements) selectKeys(
|
||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
result := make(map[string]map[string]api.KeyBackupSession)
|
||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return unpackKeys(ctx, rows)
|
||||
}
|
||||
|
||||
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
|
||||
result := make(map[string]map[string]api.KeyBackupSession)
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
|
||||
for rows.Next() {
|
||||
var key api.InternalKeyBackupSession
|
||||
|
@ -65,7 +65,6 @@ type keyBackupVersionStatements struct {
|
||||
updateKeyBackupETagStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// nolint:unused
|
||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
||||
if err != nil {
|
||||
|
@ -100,13 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
/*
|
||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.keyBackups.prepare(db); err != nil {
|
||||
return nil, err
|
||||
} */
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
@ -459,6 +458,37 @@ func (d *Database) GetKeyBackup(
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetBackupKeys(
|
||||
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
if filterSessionID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||
return err
|
||||
}
|
||||
if filterRoomID != "" {
|
||||
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||
return err
|
||||
}
|
||||
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) CountBackupKeys(
|
||||
ctx context.Context, version, userID string,
|
||||
) (count int64, err error) {
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:nakedret
|
||||
func (d *Database) UpsertBackupKeys(
|
||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||
@ -486,7 +516,7 @@ func (d *Database) UpsertBackupKeys(
|
||||
if existingRoom != nil {
|
||||
existingSession, ok := existingRoom[newKey.SessionID]
|
||||
if ok {
|
||||
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
|
||||
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
||||
changed = true
|
||||
if err != nil {
|
||||
@ -531,22 +561,3 @@ func (d *Database) UpsertBackupKeys(
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
|
||||
// create circular import loops
|
||||
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
|
||||
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
||||
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
|
||||
if uploaded.IsVerified && !existing.IsVerified {
|
||||
return true
|
||||
}
|
||||
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
|
||||
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
|
||||
return true
|
||||
}
|
||||
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
|
||||
if uploaded.ForwardedCount < existing.ForwardedCount {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user