2020-07-14 12:59:07 +01:00
|
|
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
|
2020-07-13 16:02:35 +01:00
|
|
|
package internal
|
|
|
|
|
|
|
|
import (
|
2020-07-14 12:59:07 +01:00
|
|
|
"bytes"
|
2020-07-13 16:02:35 +01:00
|
|
|
"context"
|
2020-07-15 18:40:41 +01:00
|
|
|
"encoding/json"
|
2020-07-14 12:59:07 +01:00
|
|
|
"fmt"
|
2020-07-13 16:02:35 +01:00
|
|
|
|
|
|
|
"github.com/matrix-org/dendrite/keyserver/api"
|
2020-07-14 12:59:07 +01:00
|
|
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
2020-07-15 18:40:41 +01:00
|
|
|
"github.com/matrix-org/gomatrixserverlib"
|
2020-07-14 12:59:07 +01:00
|
|
|
"github.com/tidwall/gjson"
|
2020-07-15 18:40:41 +01:00
|
|
|
"github.com/tidwall/sjson"
|
2020-07-13 16:02:35 +01:00
|
|
|
)
|
|
|
|
|
2020-07-14 12:59:07 +01:00
|
|
|
type KeyInternalAPI struct {
|
2020-07-15 18:40:41 +01:00
|
|
|
DB storage.Database
|
|
|
|
ThisServer gomatrixserverlib.ServerName
|
2020-07-14 12:59:07 +01:00
|
|
|
}
|
2020-07-13 16:02:35 +01:00
|
|
|
|
|
|
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
2020-07-14 12:59:07 +01:00
|
|
|
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
|
|
|
a.uploadDeviceKeys(ctx, req, res)
|
|
|
|
a.uploadOneTimeKeys(ctx, req, res)
|
2020-07-13 16:02:35 +01:00
|
|
|
}
|
2020-07-21 14:47:53 +01:00
|
|
|
|
2020-07-13 16:02:35 +01:00
|
|
|
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
|
2020-07-21 14:47:53 +01:00
|
|
|
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
|
|
|
res.Failures = make(map[string]interface{})
|
|
|
|
// wrap request map in a top-level by-domain map
|
|
|
|
domainToDeviceKeys := make(map[string]map[string]map[string]string)
|
|
|
|
for userID, val := range req.OneTimeKeys {
|
|
|
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
|
|
|
if err != nil {
|
|
|
|
continue // ignore invalid users
|
|
|
|
}
|
|
|
|
nested, ok := domainToDeviceKeys[string(serverName)]
|
|
|
|
if !ok {
|
|
|
|
nested = make(map[string]map[string]string)
|
|
|
|
}
|
|
|
|
nested[userID] = val
|
|
|
|
domainToDeviceKeys[string(serverName)] = nested
|
|
|
|
}
|
|
|
|
// claim local keys
|
|
|
|
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
|
|
|
|
keys, err := a.DB.ClaimKeys(ctx, local)
|
|
|
|
if err != nil {
|
|
|
|
res.Error = &api.KeyError{
|
|
|
|
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
mergeInto(res.OneTimeKeys, keys)
|
|
|
|
delete(domainToDeviceKeys, string(a.ThisServer))
|
|
|
|
}
|
|
|
|
// TODO: claim remote keys
|
2020-07-13 16:02:35 +01:00
|
|
|
|
|
|
|
}
|
2020-07-21 14:47:53 +01:00
|
|
|
|
2020-07-13 16:02:35 +01:00
|
|
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
2020-07-15 18:40:41 +01:00
|
|
|
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
|
|
|
res.Failures = make(map[string]interface{})
|
|
|
|
// make a map from domain to device keys
|
|
|
|
domainToUserToDevice := make(map[string][]api.DeviceKeys)
|
|
|
|
for userID, deviceIDs := range req.UserToDevices {
|
|
|
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
|
|
|
if err != nil {
|
|
|
|
continue // ignore invalid users
|
|
|
|
}
|
|
|
|
domain := string(serverName)
|
|
|
|
// query local devices
|
|
|
|
if serverName == a.ThisServer {
|
|
|
|
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
|
|
|
if err != nil {
|
|
|
|
res.Error = &api.KeyError{
|
|
|
|
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if res.DeviceKeys[userID] == nil {
|
|
|
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
|
|
|
}
|
|
|
|
for _, dk := range deviceKeys {
|
|
|
|
// inject an empty 'unsigned' key which should be used for display names
|
|
|
|
// (but not via this API? unsure when they should be added)
|
|
|
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{})
|
|
|
|
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for _, deviceID := range deviceIDs {
|
|
|
|
domainToUserToDevice[domain] = append(domainToUserToDevice[domain], api.DeviceKeys{
|
|
|
|
UserID: userID,
|
|
|
|
DeviceID: deviceID,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// TODO: set device display names when they are known
|
|
|
|
// TODO: perform key queries for remote devices
|
2020-07-13 16:02:35 +01:00
|
|
|
}
|
2020-07-14 12:59:07 +01:00
|
|
|
|
|
|
|
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
|
|
|
var keysToStore []api.DeviceKeys
|
|
|
|
// assert that the user ID / device ID are not lying for each key
|
|
|
|
for _, key := range req.DeviceKeys {
|
|
|
|
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
|
|
|
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
|
|
|
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
|
|
|
keysToStore = append(keysToStore, key)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
2020-07-15 12:02:34 +01:00
|
|
|
Err: fmt.Sprintf(
|
2020-07-14 12:59:07 +01:00
|
|
|
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
|
|
|
gotUserID, key.UserID, gotDeviceID, key.DeviceID,
|
|
|
|
),
|
|
|
|
})
|
|
|
|
}
|
|
|
|
// get existing device keys so we can check for changes
|
|
|
|
existingKeys := make([]api.DeviceKeys, len(keysToStore))
|
|
|
|
for i := range keysToStore {
|
|
|
|
existingKeys[i] = api.DeviceKeys{
|
|
|
|
UserID: keysToStore[i].UserID,
|
|
|
|
DeviceID: keysToStore[i].DeviceID,
|
|
|
|
}
|
|
|
|
}
|
2020-07-15 12:02:34 +01:00
|
|
|
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
2020-07-14 12:59:07 +01:00
|
|
|
res.Error = &api.KeyError{
|
2020-07-15 12:02:34 +01:00
|
|
|
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
2020-07-14 12:59:07 +01:00
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
// store the device keys and emit changes
|
2020-07-15 12:02:34 +01:00
|
|
|
if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil {
|
2020-07-14 12:59:07 +01:00
|
|
|
res.Error = &api.KeyError{
|
2020-07-15 12:02:34 +01:00
|
|
|
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
2020-07-14 12:59:07 +01:00
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
a.emitDeviceKeyChanges(existingKeys, keysToStore)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
|
|
|
for _, key := range req.OneTimeKeys {
|
|
|
|
// grab existing keys based on (user/device/algorithm/key ID)
|
|
|
|
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
|
|
|
i := 0
|
|
|
|
for keyIDWithAlgo := range key.KeyJSON {
|
|
|
|
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
|
|
|
i++
|
|
|
|
}
|
2020-07-15 12:02:34 +01:00
|
|
|
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, key.UserID, key.DeviceID, keyIDsWithAlgorithms)
|
2020-07-14 12:59:07 +01:00
|
|
|
if err != nil {
|
|
|
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
2020-07-15 12:02:34 +01:00
|
|
|
Err: "failed to query existing one-time keys: " + err.Error(),
|
2020-07-14 12:59:07 +01:00
|
|
|
})
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
for keyIDWithAlgo := range existingKeys {
|
|
|
|
// if keys exist and the JSON doesn't match, error out as the key already exists
|
|
|
|
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
|
|
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
2020-07-15 12:02:34 +01:00
|
|
|
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", key.UserID, key.DeviceID, keyIDWithAlgo),
|
2020-07-14 12:59:07 +01:00
|
|
|
})
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// store one-time keys
|
2020-07-15 12:02:34 +01:00
|
|
|
counts, err := a.DB.StoreOneTimeKeys(ctx, key)
|
|
|
|
if err != nil {
|
2020-07-14 12:59:07 +01:00
|
|
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
2020-07-15 12:02:34 +01:00
|
|
|
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()),
|
2020-07-14 12:59:07 +01:00
|
|
|
})
|
2020-07-15 12:02:34 +01:00
|
|
|
continue
|
2020-07-14 12:59:07 +01:00
|
|
|
}
|
2020-07-15 12:02:34 +01:00
|
|
|
// collect counts
|
|
|
|
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
2020-07-14 12:59:07 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
|
|
|
|
// TODO
|
|
|
|
}
|
2020-07-21 14:47:53 +01:00
|
|
|
|
|
|
|
func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
|
|
|
|
for _, key := range src {
|
|
|
|
_, ok := dst[key.UserID]
|
|
|
|
if !ok {
|
|
|
|
dst[key.UserID] = make(map[string]map[string]json.RawMessage)
|
|
|
|
}
|
|
|
|
_, ok = dst[key.UserID][key.DeviceID]
|
|
|
|
if !ok {
|
|
|
|
dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
|
|
|
}
|
|
|
|
for keyID, keyJSON := range key.KeyJSON {
|
|
|
|
dst[key.UserID][key.DeviceID][keyID] = keyJSON
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|