mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-25 05:01:41 +00:00
Only call key update process functions if there are updates, don't send things to ourselves over federation
This commit is contained in:
parent
446819e4ac
commit
aad81b7b4d
@ -210,6 +210,7 @@ func (oqs *OutgoingQueues) SendEvent(
|
|||||||
destmap[d] = struct{}{}
|
destmap[d] = struct{}{}
|
||||||
}
|
}
|
||||||
delete(destmap, oqs.origin)
|
delete(destmap, oqs.origin)
|
||||||
|
delete(destmap, oqs.signing.ServerName)
|
||||||
|
|
||||||
// Check if any of the destinations are prohibited by server ACLs.
|
// Check if any of the destinations are prohibited by server ACLs.
|
||||||
for destination := range destmap {
|
for destination := range destmap {
|
||||||
@ -275,6 +276,7 @@ func (oqs *OutgoingQueues) SendEDU(
|
|||||||
destmap[d] = struct{}{}
|
destmap[d] = struct{}{}
|
||||||
}
|
}
|
||||||
delete(destmap, oqs.origin)
|
delete(destmap, oqs.origin)
|
||||||
|
delete(destmap, oqs.signing.ServerName)
|
||||||
|
|
||||||
// There is absolutely no guarantee that the EDU will have a room_id
|
// There is absolutely no guarantee that the EDU will have a room_id
|
||||||
// field, as it is not required by the spec. However, if it *does*
|
// field, as it is not required by the spec. However, if it *does*
|
||||||
|
@ -124,6 +124,7 @@ func Send(
|
|||||||
t := txnReq{
|
t := txnReq{
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
keys: keys,
|
keys: keys,
|
||||||
|
ourServerName: cfg.Matrix.ServerName,
|
||||||
federation: federation,
|
federation: federation,
|
||||||
servers: servers,
|
servers: servers,
|
||||||
keyAPI: keyAPI,
|
keyAPI: keyAPI,
|
||||||
@ -183,6 +184,7 @@ type txnReq struct {
|
|||||||
gomatrixserverlib.Transaction
|
gomatrixserverlib.Transaction
|
||||||
rsAPI api.RoomserverInternalAPI
|
rsAPI api.RoomserverInternalAPI
|
||||||
keyAPI keyapi.KeyInternalAPI
|
keyAPI keyapi.KeyInternalAPI
|
||||||
|
ourServerName gomatrixserverlib.ServerName
|
||||||
keys gomatrixserverlib.JSONVerifier
|
keys gomatrixserverlib.JSONVerifier
|
||||||
federation txnFederationClient
|
federation txnFederationClient
|
||||||
roomsMu *internal.MutexByRoom
|
roomsMu *internal.MutexByRoom
|
||||||
@ -303,6 +305,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
|
|||||||
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gocyclo
|
||||||
func (t *txnReq) processEDUs(ctx context.Context) {
|
func (t *txnReq) processEDUs(ctx context.Context) {
|
||||||
for _, e := range t.EDUs {
|
for _, e := range t.EDUs {
|
||||||
eduCountTotal.Inc()
|
eduCountTotal.Inc()
|
||||||
@ -318,13 +321,11 @@ func (t *txnReq) processEDUs(ctx context.Context) {
|
|||||||
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event")
|
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
_, domain, err := gomatrixserverlib.SplitID('@', typingPayload.UserID)
|
if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from typing event sender")
|
|
||||||
continue
|
continue
|
||||||
}
|
} else if serverName == t.ourServerName {
|
||||||
if domain != t.Origin {
|
continue
|
||||||
util.GetLogger(ctx).Debugf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin)
|
} else if serverName != t.Origin {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
|
if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
|
||||||
@ -337,6 +338,13 @@ func (t *txnReq) processEDUs(ctx context.Context) {
|
|||||||
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events")
|
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil {
|
||||||
|
continue
|
||||||
|
} else if serverName == t.ourServerName {
|
||||||
|
continue
|
||||||
|
} else if serverName != t.Origin {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for userID, byUser := range directPayload.Messages {
|
for userID, byUser := range directPayload.Messages {
|
||||||
for deviceID, message := range byUser {
|
for deviceID, message := range byUser {
|
||||||
// TODO: check that the user and the device actually exist here
|
// TODO: check that the user and the device actually exist here
|
||||||
@ -405,6 +413,13 @@ func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) e
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, content := range payload.Push {
|
for _, content := range payload.Push {
|
||||||
|
if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil {
|
||||||
|
continue
|
||||||
|
} else if serverName == t.ourServerName {
|
||||||
|
continue
|
||||||
|
} else if serverName != t.Origin {
|
||||||
|
continue
|
||||||
|
}
|
||||||
presence, ok := syncTypes.PresenceFromString(content.Presence)
|
presence, ok := syncTypes.PresenceFromString(content.Presence)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
@ -424,7 +439,13 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli
|
|||||||
}).Debug("Failed to unmarshal signing key update")
|
}).Debug("Failed to unmarshal signing key update")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
|
||||||
|
return nil
|
||||||
|
} else if serverName == t.ourServerName {
|
||||||
|
return nil
|
||||||
|
} else if serverName != t.Origin {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
keys := gomatrixserverlib.CrossSigningKeys{}
|
keys := gomatrixserverlib.CrossSigningKeys{}
|
||||||
if updatePayload.MasterKey != nil {
|
if updatePayload.MasterKey != nil {
|
||||||
keys.MasterKey = *updatePayload.MasterKey
|
keys.MasterKey = *updatePayload.MasterKey
|
||||||
@ -450,6 +471,13 @@ func (t *txnReq) processReceiptEvent(ctx context.Context,
|
|||||||
timestamp gomatrixserverlib.Timestamp,
|
timestamp gomatrixserverlib.Timestamp,
|
||||||
eventIDs []string,
|
eventIDs []string,
|
||||||
) error {
|
) error {
|
||||||
|
if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil {
|
||||||
|
return nil
|
||||||
|
} else if serverName == t.ourServerName {
|
||||||
|
return nil
|
||||||
|
} else if serverName != t.Origin {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
// store every event
|
// store every event
|
||||||
for _, eventID := range eventIDs {
|
for _, eventID := range eventIDs {
|
||||||
if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil {
|
if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil {
|
||||||
@ -466,6 +494,13 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
|
|||||||
util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event")
|
util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if _, serverName, err := gomatrixserverlib.SplitID('@', payload.UserID); err != nil {
|
||||||
|
return
|
||||||
|
} else if serverName == t.ourServerName {
|
||||||
|
return
|
||||||
|
} else if serverName != t.Origin {
|
||||||
|
return
|
||||||
|
}
|
||||||
var inputRes keyapi.InputDeviceListUpdateResponse
|
var inputRes keyapi.InputDeviceListUpdateResponse
|
||||||
t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{
|
t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{
|
||||||
Event: payload,
|
Event: payload,
|
||||||
|
@ -71,9 +71,13 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
|
|||||||
|
|
||||||
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
||||||
|
if len(req.DeviceKeys) > 0 {
|
||||||
a.uploadLocalDeviceKeys(ctx, req, res)
|
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||||
|
}
|
||||||
|
if len(req.OneTimeKeys) > 0 {
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
|
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
|
||||||
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
||||||
@ -663,6 +667,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||||||
// add the display name field from keysToStore into existingKeys
|
// add the display name field from keysToStore into existingKeys
|
||||||
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
// store the device keys and emit changes
|
// store the device keys and emit changes
|
||||||
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user