Only call key update process functions if there are updates, don't send things to ourselves over federation

This commit is contained in:
Neil Alexander 2022-04-25 14:22:46 +01:00
parent 446819e4ac
commit aad81b7b4d
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
3 changed files with 51 additions and 9 deletions

View File

@ -210,6 +210,7 @@ func (oqs *OutgoingQueues) SendEvent(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName)
// Check if any of the destinations are prohibited by server ACLs. // Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap { for destination := range destmap {
@ -275,6 +276,7 @@ func (oqs *OutgoingQueues) SendEDU(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName)
// There is absolutely no guarantee that the EDU will have a room_id // There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does* // field, as it is not required by the spec. However, if it *does*

View File

@ -124,6 +124,7 @@ func Send(
t := txnReq{ t := txnReq{
rsAPI: rsAPI, rsAPI: rsAPI,
keys: keys, keys: keys,
ourServerName: cfg.Matrix.ServerName,
federation: federation, federation: federation,
servers: servers, servers: servers,
keyAPI: keyAPI, keyAPI: keyAPI,
@ -183,6 +184,7 @@ type txnReq struct {
gomatrixserverlib.Transaction gomatrixserverlib.Transaction
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
keyAPI keyapi.KeyInternalAPI keyAPI keyapi.KeyInternalAPI
ourServerName gomatrixserverlib.ServerName
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
federation txnFederationClient federation txnFederationClient
roomsMu *internal.MutexByRoom roomsMu *internal.MutexByRoom
@ -303,6 +305,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
return &gomatrixserverlib.RespSend{PDUs: results}, nil return &gomatrixserverlib.RespSend{PDUs: results}, nil
} }
// nolint:gocyclo
func (t *txnReq) processEDUs(ctx context.Context) { func (t *txnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs { for _, e := range t.EDUs {
eduCountTotal.Inc() eduCountTotal.Inc()
@ -318,13 +321,11 @@ func (t *txnReq) processEDUs(ctx context.Context) {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event")
continue continue
} }
_, domain, err := gomatrixserverlib.SplitID('@', typingPayload.UserID) if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil {
if err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from typing event sender")
continue continue
} } else if serverName == t.ourServerName {
if domain != t.Origin { continue
util.GetLogger(ctx).Debugf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) } else if serverName != t.Origin {
continue continue
} }
if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
@ -337,6 +338,13 @@ func (t *txnReq) processEDUs(ctx context.Context) {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events")
continue continue
} }
if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
for userID, byUser := range directPayload.Messages { for userID, byUser := range directPayload.Messages {
for deviceID, message := range byUser { for deviceID, message := range byUser {
// TODO: check that the user and the device actually exist here // TODO: check that the user and the device actually exist here
@ -405,6 +413,13 @@ func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) e
return err return err
} }
for _, content := range payload.Push { for _, content := range payload.Push {
if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
presence, ok := syncTypes.PresenceFromString(content.Presence) presence, ok := syncTypes.PresenceFromString(content.Presence)
if !ok { if !ok {
continue continue
@ -424,7 +439,13 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli
}).Debug("Failed to unmarshal signing key update") }).Debug("Failed to unmarshal signing key update")
return err return err
} }
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
return nil
} else if serverName == t.ourServerName {
return nil
} else if serverName != t.Origin {
return nil
}
keys := gomatrixserverlib.CrossSigningKeys{} keys := gomatrixserverlib.CrossSigningKeys{}
if updatePayload.MasterKey != nil { if updatePayload.MasterKey != nil {
keys.MasterKey = *updatePayload.MasterKey keys.MasterKey = *updatePayload.MasterKey
@ -450,6 +471,13 @@ func (t *txnReq) processReceiptEvent(ctx context.Context,
timestamp gomatrixserverlib.Timestamp, timestamp gomatrixserverlib.Timestamp,
eventIDs []string, eventIDs []string,
) error { ) error {
if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil {
return nil
} else if serverName == t.ourServerName {
return nil
} else if serverName != t.Origin {
return nil
}
// store every event // store every event
for _, eventID := range eventIDs { for _, eventID := range eventIDs {
if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil {
@ -466,6 +494,13 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event") util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event")
return return
} }
if _, serverName, err := gomatrixserverlib.SplitID('@', payload.UserID); err != nil {
return
} else if serverName == t.ourServerName {
return
} else if serverName != t.Origin {
return
}
var inputRes keyapi.InputDeviceListUpdateResponse var inputRes keyapi.InputDeviceListUpdateResponse
t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{ t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{
Event: payload, Event: payload,

View File

@ -71,8 +71,12 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError) res.KeyErrors = make(map[string]map[string]*api.KeyError)
a.uploadLocalDeviceKeys(ctx, req, res) if len(req.DeviceKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res) a.uploadLocalDeviceKeys(ctx, req, res)
}
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
} }
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
@ -663,6 +667,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
// add the display name field from keysToStore into existingKeys // add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore) keysToStore = appendDisplayNames(existingKeys, keysToStore)
} }
// store the device keys and emit changes // store the device keys and emit changes
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil { if err != nil {