diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 1b7670e9..a638a574 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -76,21 +76,25 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re return } - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingPDUs) < maxPDUsInMemory { - oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() + // Check if the destination is blacklisted. If it isn't then wake + // up the queue. + if !oq.statistics.Blacklisted() { + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingPDUs) < maxPDUsInMemory { + oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ + pdu: event, + receipt: receipt, + }) + } else { + oq.overflowed.Store(true) + } + oq.pendingMutex.Unlock() - if !oq.backingOff.Load() { - oq.wakeQueueAndNotify() + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } } } @@ -103,21 +107,25 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share return } - // If there's room in memory to hold the event then add it to the - // list. - oq.pendingMutex.Lock() - if len(oq.pendingEDUs) < maxEDUsInMemory { - oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, - }) - } else { - oq.overflowed.Store(true) - } - oq.pendingMutex.Unlock() + // Check if the destination is blacklisted. If it isn't then wake + // up the queue. + if !oq.statistics.Blacklisted() { + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingEDUs) < maxEDUsInMemory { + oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ + edu: event, + receipt: receipt, + }) + } else { + oq.overflowed.Store(true) + } + oq.pendingMutex.Unlock() - if !oq.backingOff.Load() { - oq.wakeQueueAndNotify() + if !oq.backingOff.Load() { + oq.wakeQueueAndNotify() + } } } diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 32833437..b5d0552c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -247,9 +247,10 @@ func (oqs *OutgoingQueues) SendEvent( return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) } + destQueues := make([]*destinationQueue, 0, len(destmap)) for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { - queue.sendEvent(ev, nid) + if queue := oqs.getQueue(destination); queue != nil { + destQueues = append(destQueues, queue) } else { delete(destmap, destination) } @@ -267,6 +268,14 @@ func (oqs *OutgoingQueues) SendEvent( return err } + // NOTE : PDUs should be associated with destinations before sending + // them, otherwise this is technically a race. + // If the send completes before they are associated then they won't + // get properly cleaned up in the database. + for _, queue := range destQueues { + queue.sendEvent(ev, nid) + } + return nil } @@ -335,20 +344,21 @@ func (oqs *OutgoingQueues) SendEDU( return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) } + destQueues := make([]*destinationQueue, 0, len(destmap)) for destination := range destmap { - if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() { - queue.sendEDU(e, nid) + if queue := oqs.getQueue(destination); queue != nil { + destQueues = append(destQueues, queue) } else { delete(destmap, destination) } } // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU + // these destination queues. We'll then be able to retrieve the PDU // later. if err := oqs.db.AssociateEDUWithDestinations( oqs.process.Context(), - destmap, // the destination server name + destmap, // the destination server names nid, // NIDs from federationapi_queue_json table e.Type, nil, // this will use the default expireEDUTypes map @@ -357,6 +367,14 @@ func (oqs *OutgoingQueues) SendEDU( return err } + // NOTE : EDUs should be associated with destinations before sending + // them, otherwise this is technically a race. + // If the send completes before they are associated then they won't + // get properly cleaned up in the database. + for _, queue := range destQueues { + queue.sendEDU(e, nid) + } + return nil }