// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package internal

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/util"
	"github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
	"github.com/matrix-org/dendrite/keyserver/api"
	"github.com/matrix-org/dendrite/keyserver/producers"
	"github.com/matrix-org/dendrite/keyserver/storage"
	userapi "github.com/matrix-org/dendrite/userapi/api"
)

// KeyInternalAPI implements the keyserver's internal API operations
// (uploading, claiming, querying and deleting keys) on top of the keyserver
// database, the federation client and the user API.
type KeyInternalAPI struct {
	// DB is the keyserver storage: device keys, one-time keys, key change
	// offsets and cross-signing signatures are read from and written to it.
	DB         storage.Database
	// ThisServer is our own server name; users on this domain are served
	// from the local database instead of over federation.
	ThisServer gomatrixserverlib.ServerName
	// FedClient is used to claim and query keys from remote servers.
	FedClient  fedsenderapi.KeyserverFederationAPI
	// UserAPI is consulted for which devices exist and for display names.
	UserAPI    userapi.KeyserverUserAPI
	// Producer is handed to emitDeviceKeyChanges to publish key changes.
	Producer   *producers.KeyChange
	// Updater performs manual device list updates for remote users.
	Updater    *DeviceListUpdater
}

// SetUserAPI replaces the user API this keyserver consults for device
// existence and display names.
func (a *KeyInternalAPI) SetUserAPI(userAPI userapi.KeyserverUserAPI) {
	a.UserAPI = userAPI
}

// QueryKeyChanges reports which users had device key changes between
// req.Offset and req.ToOffset, along with the latest offset returned by the
// database. A database failure is reported through res.Error; the returned
// error is always nil.
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
	changed, offset, dbErr := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
	if dbErr == nil {
		res.Offset, res.UserIDs = offset, changed
		return nil
	}
	res.Error = &api.KeyError{Err: dbErr.Error()}
	return nil
}

// PerformUploadKeys stores any device keys and one-time keys carried by the
// request, then reports the one-time key counts for req.DeviceID. Per-key and
// upload problems are recorded in res (KeyErrors / Error); only a failure to
// count the one-time keys is returned as an error.
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
	res.KeyErrors = map[string]map[string]*api.KeyError{}

	if len(req.DeviceKeys) != 0 {
		a.uploadLocalDeviceKeys(ctx, req, res)
	}
	if len(req.OneTimeKeys) != 0 {
		a.uploadOneTimeKeys(ctx, req, res)
	}

	counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
	if err != nil {
		return err
	}
	// Replace (rather than extend) any counts gathered above with a single
	// fresh count for this device.
	res.OneTimeKeyCounts = []api.OneTimeKeysCount{*counts}
	return nil
}

// PerformClaimKeys claims one-time keys for the users listed in
// req.OneTimeKeys. Requests are grouped by the server name of each user ID,
// and invalid user IDs are skipped. Users on this server are claimed from the
// local database; a local failure is reported in res.Error. Everyone else is
// claimed over federation via claimRemoteKeys. The returned error is always nil.
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
	res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
	res.Failures = make(map[string]interface{})

	// server name -> user ID -> that user's per-device claim request
	byDomain := make(map[string]map[string]map[string]string)
	for userID, devices := range req.OneTimeKeys {
		_, server, err := gomatrixserverlib.SplitID('@', userID)
		if err != nil {
			continue // ignore invalid users
		}
		domain := string(server)
		if byDomain[domain] == nil {
			byDomain[domain] = make(map[string]map[string]string)
		}
		byDomain[domain][userID] = devices
	}

	localDomain := string(a.ThisServer)
	if localReq, found := byDomain[localDomain]; found {
		claimed, err := a.DB.ClaimKeys(ctx, localReq)
		if err != nil {
			res.Error = &api.KeyError{
				Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
			}
		}
		util.GetLogger(ctx).WithField("keys_claimed", len(claimed)).WithField("num_users", len(localReq)).Info("Claimed local keys")
		for _, otk := range claimed {
			devices := res.OneTimeKeys[otk.UserID]
			if devices == nil {
				devices = make(map[string]map[string]json.RawMessage)
				res.OneTimeKeys[otk.UserID] = devices
			}
			keyMap := devices[otk.DeviceID]
			if keyMap == nil {
				keyMap = make(map[string]json.RawMessage)
				devices[otk.DeviceID] = keyMap
			}
			for keyID, keyJSON := range otk.KeyJSON {
				keyMap[keyID] = keyJSON
			}
		}
		// local users are done; whatever remains must go over federation
		delete(byDomain, localDomain)
	}

	if len(byDomain) != 0 {
		a.claimRemoteKeys(ctx, req.Timeout, res, byDomain)
	}
	return nil
}

// claimRemoteKeys claims one-time keys from every remote server in
// domainToDeviceKeys concurrently and merges the results into res.
//
// domainToDeviceKeys maps server name -> user ID -> the per-device claim
// request that is forwarded unchanged to that server. A positive timeout
// bounds each federation request. A zero or negative timeout applies no extra
// deadline beyond ctx's own, matching queryRemoteKeysOnServer. Servers whose
// request fails are recorded in res.Failures under their server name. Keys a
// server returns for users we did not ask it about are ignored, so one server
// cannot overwrite or inject keys for users belonging to another server.
func (a *KeyInternalAPI) claimRemoteKeys(
	ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
) {
	var wg sync.WaitGroup // Wait for fan-out goroutines to finish
	var mu sync.Mutex     // Protects the response struct
	var claimed int       // Number of keys claimed in total
	var failures int      // Number of servers we failed to ask

	util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
	wg.Add(len(domainToDeviceKeys))

	for d, k := range domainToDeviceKeys {
		go func(domain string, keysToClaim map[string]map[string]string) {
			defer wg.Done()

			// context.WithTimeout with a zero/negative duration expires
			// immediately, which would fail every claim, so only apply a
			// deadline when one was actually requested.
			fedCtx := ctx
			if timeout > 0 {
				var cancel context.CancelFunc
				fedCtx, cancel = context.WithTimeout(ctx, timeout)
				defer cancel()
			}

			claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)

			mu.Lock()
			defer mu.Unlock()

			if err != nil {
				util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
				res.Failures[domain] = map[string]interface{}{
					"message": err.Error(),
				}
				failures++
				return
			}

			for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
				if _, requested := keysToClaim[userID]; !requested {
					util.GetLogger(ctx).WithField("server", domain).WithField("user_id", userID).Warn("Ignoring claimed keys for a user not requested from this server")
					continue
				}
				// merge instead of replacing, so nothing already collected
				// for this user is discarded
				userKeys := res.OneTimeKeys[userID]
				if userKeys == nil {
					userKeys = make(map[string]map[string]json.RawMessage)
					res.OneTimeKeys[userID] = userKeys
				}
				for deviceID, keys := range deviceIDToKeys {
					userKeys[deviceID] = keys
					claimed += len(keys)
				}
			}
		}(d, k)
	}

	wg.Wait()
	util.GetLogger(ctx).WithFields(logrus.Fields{
		"num_keys":     claimed,
		"num_failures": failures,
	}).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
}

// PerformDeleteKeys deletes the device keys identified by req.KeyIDs for
// req.UserID. A database failure is reported through res.Error; the returned
// error is always nil.
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
	err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs)
	if err == nil {
		return nil
	}
	res.Error = &api.KeyError{Err: fmt.Sprintf("Failed to delete device keys: %s", err)}
	return nil
}

// QueryOneTimeKeys fills res.Count with the one-time key counts for
// req.UserID / req.DeviceID. A database failure is reported through
// res.Error; the returned error is always nil.
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
	if count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID); err != nil {
		res.Error = &api.KeyError{Err: fmt.Sprintf("Failed to query OTK counts: %s", err)}
	} else {
		res.Count = *count
	}
	return nil
}

// QueryDeviceMessages returns the stored device key entries for req.UserID.
// Entries with empty key JSON are treated as deleted devices and omitted.
// res.StreamID is the highest stream ID seen across all entries, including
// omitted ones. A database failure is reported through res.Error; the
// returned error is always nil.
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
	stored, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
	if err != nil {
		res.Error = &api.KeyError{
			Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
		}
		return nil
	}

	var (
		live      []api.DeviceMessage
		highestID int64
	)
	for i := range stored {
		msg := stored[i]
		if msg.StreamID > highestID {
			highestID = msg.StreamID
		}
		// empty key JSON marks a deleted device
		if len(msg.KeyJSON) != 0 {
			live = append(live, msg)
		}
	}

	res.Devices = live
	res.StreamID = highestID
	return nil
}

// PerformMarkAsStaleIfNeeded triggers a manual device list update for
// req.UserID on req.Domain when req.DeviceID is not among the devices we
// already hold keys for. Users with no stored devices at all are left alone.
func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
	known, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
	if err != nil {
		return err
	}
	if len(known) == 0 {
		// nothing stored for this user, so there is nothing to mark stale
		return nil
	}
	for _, device := range known {
		if device.DeviceID == req.DeviceID {
			return nil // device already known
		}
	}
	return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
}

// QueryKeys returns device keys and master/self-signing cross-signing keys
// for every user in req.UserToDevices. An empty device list requests all of
// that user's devices. It then attaches any cross-signing signatures stored
// for those keys. Signatures are looked up with req.UserID (presumably the
// requesting user; confirm against callers).
//
// Local users are served from the database, with display names taken from
// the user API when available. Remote users are served from the database
// when specific devices were requested and all are known; otherwise they are
// queried over federation. Remote failures land in res.Failures. A local
// database failure is reported in res.Error. The returned error is always nil.
//
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
	var respMu sync.Mutex
	res.DeviceKeys = make(map[string]map[string]json.RawMessage)
	res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
	res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
	res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
	res.Failures = make(map[string]interface{})

	// make a map from domain to device keys
	domainToDeviceKeys := make(map[string]map[string][]string)
	domainToCrossSigningKeys := make(map[string]map[string]struct{})
	for userID, deviceIDs := range req.UserToDevices {
		_, serverName, err := gomatrixserverlib.SplitID('@', userID)
		if err != nil {
			continue // ignore invalid users
		}
		domain := string(serverName)
		// query local devices
		if serverName == a.ThisServer {
			deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
			if err != nil {
				res.Error = &api.KeyError{
					Err: fmt.Sprintf("failed to query local device keys: %s", err),
				}
				return nil
			}

			// pull out display names after we have the keys so we handle wildcards correctly
			var dids []string
			for _, dk := range deviceKeys {
				dids = append(dids, dk.DeviceID)
			}
			var queryRes userapi.QueryDeviceInfosResponse
			err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
				DeviceIDs: dids,
			}, &queryRes)
			if err != nil {
				util.GetLogger(ctx).WithError(err).Warn("Failed to QueryDeviceInfos for device IDs, display names will be missing")
			}

			if res.DeviceKeys[userID] == nil {
				res.DeviceKeys[userID] = make(map[string]json.RawMessage)
			}
			for _, dk := range deviceKeys {
				if len(dk.KeyJSON) == 0 {
					continue // don't include blank keys
				}
				// inject display name if known (either locally or remotely)
				displayName := dk.DisplayName
				if info := queryRes.DeviceInfo[dk.DeviceID]; info.DisplayName != "" {
					displayName = info.DisplayName
				}
				dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
					DisplayName string `json:"device_display_name,omitempty"`
				}{displayName})
				res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
			}
		} else {
			// Several users can share a remote server, so only create the
			// per-domain map once; recreating it would drop every user
			// already collected for that domain.
			if domainToDeviceKeys[domain] == nil {
				domainToDeviceKeys[domain] = make(map[string][]string)
			}
			domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
		}
		// work out if our cross-signing request for this user was
		// satisfied, if not add them to the list of things to fetch
		if _, ok := res.MasterKeys[userID]; !ok {
			if _, ok := domainToCrossSigningKeys[domain]; !ok {
				domainToCrossSigningKeys[domain] = make(map[string]struct{})
			}
			domainToCrossSigningKeys[domain][userID] = struct{}{}
		}
		if _, ok := res.SelfSigningKeys[userID]; !ok {
			if _, ok := domainToCrossSigningKeys[domain]; !ok {
				domainToCrossSigningKeys[domain] = make(map[string]struct{})
			}
			domainToCrossSigningKeys[domain][userID] = struct{}{}
		}
	}

	// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
	domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
	if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
		// perform key queries for remote devices
		a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
	}

	// Now that we've done the potentially expensive work of asking the federation,
	// try filling the cross-signing keys from the database that we know about.
	a.crossSigningKeysFromDatabase(ctx, req, res)

	// Finally, append signatures that we know about
	// TODO: This is horrible because we need to round-trip the signature from
	// JSON, add the signatures and marshal it again, for some reason?

	for targetUserID, masterKey := range res.MasterKeys {
		if masterKey.Signatures == nil {
			masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
		}
		for targetKeyID := range masterKey.Keys {
			sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
			if err != nil {
				// Stop executing the function if the context was canceled/the deadline was exceeded,
				// as we can't continue without a valid context.
				if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
					return nil
				}
				logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
				continue
			}
			if len(sigMap) == 0 {
				continue
			}
			for sourceUserID, forSourceUser := range sigMap {
				for sourceKeyID, sourceSig := range forSourceUser {
					if _, ok := masterKey.Signatures[sourceUserID]; !ok {
						masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
					}
					masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
				}
			}
		}
		// masterKey is a copy of the map value. If Signatures started out nil,
		// the map allocated above exists only on this copy, so store it back
		// or the signatures collected above would be lost.
		if len(masterKey.Signatures) > 0 {
			res.MasterKeys[targetUserID] = masterKey
		}
	}

	for targetUserID, forUserID := range res.DeviceKeys {
		for targetKeyID, key := range forUserID {
			sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
			if err != nil {
				// Stop executing the function if the context was canceled/the deadline was exceeded,
				// as we can't continue without a valid context.
				if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
					return nil
				}
				logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
				continue
			}
			if len(sigMap) == 0 {
				continue
			}
			var deviceKey gomatrixserverlib.DeviceKeys
			if err = json.Unmarshal(key, &deviceKey); err != nil {
				continue
			}
			if deviceKey.Signatures == nil {
				deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
			}
			for sourceUserID, forSourceUser := range sigMap {
				for sourceKeyID, sourceSig := range forSourceUser {
					if _, ok := deviceKey.Signatures[sourceUserID]; !ok {
						deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
					}
					deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
				}
			}
			if js, err := json.Marshal(deviceKey); err == nil {
				res.DeviceKeys[targetUserID][targetKeyID] = js
			}
		}
	}
	return nil
}

// remoteKeysFromDatabase tries to answer remote users' device key queries
// from the local database, writing hits into res under respMu. A user is
// served locally only when specific device IDs were requested and all of them
// are stored. Everything else, including users for whom all devices were
// requested, is returned as server name -> user ID -> device IDs to be
// fetched over federation.
func (a *KeyInternalAPI) remoteKeysFromDatabase(
	ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
	unresolved := make(map[string]map[string][]string)
	for domain, users := range domainToDeviceKeys {
		for userID, deviceIDs := range users {
			// An empty device list means "all devices". The database can't
			// tell us whether one has just been added, so always go remote.
			if len(deviceIDs) != 0 {
				err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
				if err == nil {
					continue
				}
				util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
			}
			pending, ok := unresolved[domain]
			if !ok {
				pending = make(map[string][]string)
				unresolved[domain] = pending
			}
			pending[userID] = append(pending[userID], deviceIDs...)
		}
	}
	return unresolved
}

// queryRemoteKeys starts one queryRemoteKeysOnServer goroutine for each remote
// server named in domainToDeviceKeys or domainToCrossSigningKeys, skipping
// our own server. It merges every successful response into res.
// Per-server failures are written into res by the workers themselves. All
// writes to res made here and by the workers share one mutex.
func (a *KeyInternalAPI) queryRemoteKeys(
	ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
	domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
) {
	local := string(a.ThisServer)
	targets := map[string]struct{}{}
	for server := range domainToDeviceKeys {
		if server != local {
			targets[server] = struct{}{}
		}
	}
	for server := range domainToCrossSigningKeys {
		if server != local {
			targets[server] = struct{}{}
		}
	}

	var (
		wg      sync.WaitGroup // outstanding per-server workers
		respMu  sync.Mutex     // guards every write to res
		results = make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
	)

	wg.Add(len(targets))
	for server := range targets {
		go a.queryRemoteKeysOnServer(
			ctx, server, domainToDeviceKeys[server], domainToCrossSigningKeys[server],
			&wg, &respMu, timeout, results, res,
		)
	}

	// Once every worker has returned, close the channel so the receive loop
	// below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()

	merge := func(result *gomatrixserverlib.RespQueryKeys) {
		respMu.Lock()
		defer respMu.Unlock()
		for userID, devices := range result.DeviceKeys {
			encoded := make(map[string]json.RawMessage)
			res.DeviceKeys[userID] = encoded
			for deviceID, deviceKey := range devices {
				raw, err := json.Marshal(deviceKey)
				if err == nil {
					encoded[deviceID] = raw
				}
			}
		}
		for userID, key := range result.MasterKeys {
			res.MasterKeys[userID] = key
		}
		for userID, key := range result.SelfSigningKeys {
			res.SelfSigningKeys[userID] = key
		}
		// TODO: do we want to persist these somewhere now
		// that we have fetched them?
	}

	for result := range results {
		merge(result)
	}
}

// queryRemoteKeysOnServer fetches keys from a single remote server. It is run
// as one of queryRemoteKeys' fan-out goroutines and calls wg.Done on return.
//
// devKeys maps user ID -> requested device IDs, where an empty list means
// all devices. crossSigningKeys is the set of users whose cross-signing keys
// are wanted. Users for whom all devices were requested are removed from the
// /keys/query batch and handled by a manual device list update followed by a
// database read. They are put back into the batch only if that fails.
// Remaining users are queried with one /keys/query, together with
// cross-signing users not covered by a device list update, and the response
// is sent on resultCh. If that request fails, the failure is recorded in
// res.Failures and the database cache is used as a last resort. Every write
// to res is made under respMu. A positive timeout bounds the federation
// request.
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
	ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
	wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
	res *api.QueryKeysResponse,
) {
	defer wg.Done()
	fedCtx := ctx
	if timeout > 0 {
		var cancel context.CancelFunc
		fedCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	// for users who we do not have any knowledge about, try to start doing device list updates for them
	// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
	// lack a stream ID.
	userIDsForAllDevices := map[string]struct{}{}
	for userID, deviceIDs := range devKeys {
		if len(deviceIDs) == 0 {
			userIDsForAllDevices[userID] = struct{}{}
			// Take them out of the /keys/query batch: the device list update
			// below handles them, and the fallbacks re-add them on failure.
			// (Deleting the current key during range is safe in Go.)
			delete(devKeys, userID)
		}
	}
	// for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
	// a device list update, so we'll populate those back into the /keys/query list if not
	for userID := range crossSigningKeys {
		if devKeys == nil {
			devKeys = map[string][]string{}
		}
		if _, ok := userIDsForAllDevices[userID]; !ok {
			devKeys[userID] = []string{}
		}
	}
	for userID := range userIDsForAllDevices {
		// NOTE(review): uses context.Background(), presumably so the update is
		// not cut short by the request context — confirm this is intended.
		err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
		if err != nil {
			logrus.WithFields(logrus.Fields{
				logrus.ErrorKey: err,
				"user_id":       userID,
				"server":        serverName,
			}).Error("Failed to manually update device lists for user")
			// try to do it via /keys/query
			devKeys[userID] = []string{}
			continue
		}
		// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
		// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
		err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
		if err != nil {
			logrus.WithFields(logrus.Fields{
				logrus.ErrorKey: err,
				"user_id":       userID,
				"server":        serverName,
			}).Error("Failed to populate device keys from database after manual device list update")
			// try to do it via /keys/query
			devKeys[userID] = []string{}
			continue
		}
	}
	if len(devKeys) == 0 {
		return
	}
	queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
	if err == nil {
		resultCh <- &queryKeysResp
		return
	}
	respMu.Lock()
	res.Failures[serverName] = map[string]interface{}{
		"message": err.Error(),
	}
	respMu.Unlock()

	// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
	// is down, better to return something than nothing at all. Clients can know about the failure by
	// inspecting the failures map though so they can know it's a cached response.
	for userID, dkeys := range devKeys {
		// drop the error as it's already a failure at this point
		_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
	}

	// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
	respMu.Lock()
	if len(res.DeviceKeys) > 0 {
		delete(res.Failures, serverName)
	}
	respMu.Unlock()
}

// populateResponseWithDeviceKeysFromDatabase copies userID's stored device
// keys into res.DeviceKeys, injecting each device's display name under the
// "unsigned" field. An empty deviceIDs slice requests every device. Entries
// with empty key JSON are skipped. It returns an error without writing
// anything in three cases: the lookup fails, fewer keys than requested are
// found, or all devices were requested but none are stored. Callers treat
// that error as the cue to fall back to federation. respMu guards res.
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
	ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
	keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
	switch {
	case err != nil:
		return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
	case len(keys) < len(deviceIDs):
		return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
	case len(deviceIDs) == 0 && len(keys) == 0:
		return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
	}

	respMu.Lock()
	if res.DeviceKeys[userID] == nil {
		res.DeviceKeys[userID] = make(map[string]json.RawMessage)
	}
	respMu.Unlock()

	for _, stored := range keys {
		if len(stored.KeyJSON) == 0 {
			continue // deleted keys are not returned
		}
		withName, _ := sjson.SetBytes(stored.KeyJSON, "unsigned", struct {
			DisplayName string `json:"device_display_name,omitempty"`
		}{stored.DisplayName})
		respMu.Lock()
		res.DeviceKeys[userID][stored.DeviceID] = withName
		respMu.Unlock()
	}
	return nil
}

// uploadLocalDeviceKeys validates and stores the device keys in
// req.DeviceKeys for req.UserID, then emits key change notifications.
//
// Three kinds of failure stop the whole upload and are reported in res.Error:
// a user API error, an unknown user, and a database error. Per-key validation
// failures are reported via res.KeyError. Stored device keys for devices the
// user API no longer reports are deleted on a best-effort basis (failures are
// only logged) and are never written back by this call.
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
	// get a list of devices from the user API that actually exist, as
	// we won't store keys for devices that don't exist
	uapidevices := &userapi.QueryDevicesResponse{}
	if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
		res.Error = &api.KeyError{
			Err: err.Error(),
		}
		return
	}
	if !uapidevices.UserExists {
		res.Error = &api.KeyError{
			Err: fmt.Sprintf("user %q does not exist", req.UserID),
		}
		return
	}
	existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
	for _, key := range uapidevices.Devices {
		existingDeviceMap[key.ID] = struct{}{}
	}

	// Get all of the user existing device keys so we can check for changes.
	existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
	if err != nil {
		res.Error = &api.KeyError{
			Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
		}
		return
	}

	// Work out whether we have device keys in the keyserver for devices that
	// no longer exist in the user API. This is mostly an exercise to ensure
	// that we keep some integrity between the two.
	var toClean []gomatrixserverlib.KeyID
	for _, k := range existingKeys {
		if _, ok := existingDeviceMap[k.DeviceID]; !ok {
			toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
		}
	}

	if len(toClean) > 0 {
		if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
			logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
		} else {
			logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
		}
	}

	var keysToStore []api.DeviceMessage

	if req.OnlyDisplayNameUpdates {
		// Re-store the user's existing keys, applying any display names from
		// the request. Entries for devices missing from the user API were
		// just cleaned up above, so skip them rather than writing them back.
		for _, existingKey := range existingKeys {
			if _, ok := existingDeviceMap[existingKey.DeviceID]; !ok {
				continue
			}
			for _, newKey := range req.DeviceKeys {
				if existingKey.UserID == newKey.UserID && existingKey.DeviceID == newKey.DeviceID {
					existingKey.DisplayName = newKey.DisplayName
				}
			}
			keysToStore = append(keysToStore, existingKey)
		}
	} else {
		// assert that the user ID / device ID are not lying for each key
		for _, key := range req.DeviceKeys {
			var serverName gomatrixserverlib.ServerName
			_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
			if err != nil {
				continue // ignore invalid users
			}
			if serverName != a.ThisServer {
				continue // ignore remote users
			}
			if len(key.KeyJSON) == 0 {
				keysToStore = append(keysToStore, key.WithStreamID(0))
				continue // deleted keys don't need sanity checking
			}
			// check that the device in question actually exists in the user
			// API before we try and store a key for it
			if _, ok := existingDeviceMap[key.DeviceID]; !ok {
				continue
			}
			gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
			gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
			if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
				keysToStore = append(keysToStore, key.WithStreamID(0))
				continue
			}

			res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
				Err: fmt.Sprintf(
					"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
					gotUserID, key.UserID, gotDeviceID, key.DeviceID,
				),
			})
		}
	}

	// store the device keys and emit changes
	err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
	if err != nil {
		res.Error = &api.KeyError{
			Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
		}
		return
	}
	// Changes are computed against the keys that existed before this call;
	// a failure to publish is logged but does not fail the upload.
	err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
	if err != nil {
		util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
	}
}

// uploadOneTimeKeys stores each batch of one-time keys in req.OneTimeKeys for
// req.UserID / req.DeviceID and appends the resulting counts to
// res.OneTimeKeyCounts.
//
// A missing user ID sets res.Error and nothing is stored. A batch is not
// stored if any of its key IDs already exists with different JSON; each
// conflict is reported via res.KeyError. Identical re-uploads are allowed.
// Lookup and store failures are also reported via res.KeyError, and
// processing continues with the next batch. When a device ID is given but no
// one-time keys are, only that device's current counts are reported.
func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
	if req.UserID == "" {
		res.Error = &api.KeyError{
			Err: "user ID  missing",
		}
		// nothing can be stored without a user ID to attach the keys to
		return
	}
	if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
		counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
		if err != nil {
			res.Error = &api.KeyError{
				Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
			}
		}
		if counts != nil {
			res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
		}
		return
	}
	for _, key := range req.OneTimeKeys {
		// grab existing keys based on (user/device/algorithm/key ID)
		keyIDsWithAlgorithms := make([]string, 0, len(key.KeyJSON))
		for keyIDWithAlgo := range key.KeyJSON {
			keyIDsWithAlgorithms = append(keyIDsWithAlgorithms, keyIDWithAlgo)
		}
		existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
		if err != nil {
			res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
				Err: "failed to query existing one-time keys: " + err.Error(),
			})
			continue
		}
		// If a key ID already exists with different JSON, report it and skip
		// storing this batch entirely, so the existing key is not replaced.
		conflict := false
		for keyIDWithAlgo, existingJSON := range existingKeys {
			if bytes.Equal(existingJSON, key.KeyJSON[keyIDWithAlgo]) {
				continue // identical re-upload
			}
			conflict = true
			res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
				Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
			})
		}
		if conflict {
			continue
		}
		// store one-time keys
		counts, err := a.DB.StoreOneTimeKeys(ctx, key)
		if err != nil {
			res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
				Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
			})
			continue
		}
		// collect counts
		res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
	}
}

// emitDeviceKeyChanges publishes device key changes through producer and
// returns the producer's error.
//
// When onlyUpdateDisplayName is set, every entry in incoming is published
// as-is. Otherwise only entries in incoming with no DeviceKeysEqual match in
// existing are published. Absent keys must not count as equal; otherwise no
// change would be emitted when a user deletes a device that never had a key,
// since both key JSONs would be nil.
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, incoming []api.DeviceMessage, onlyUpdateDisplayName bool) error {
	if onlyUpdateDisplayName {
		// display-name-only updates skip the diff entirely
		return producer.ProduceKeyChanges(incoming)
	}

	var changed []api.DeviceMessage
nextKey:
	for i := range incoming {
		candidate := incoming[i]
		for _, prior := range existing {
			if prior.DeviceKeysEqual(&candidate) {
				continue nextKey
			}
		}
		changed = append(changed, candidate)
	}
	return producer.ProduceKeyChanges(changed)
}