Refactor Federation Destination Queues (#2807)

This is a refactor of the federation destination queues.
It fixes a few things, namely:
- actually retry outgoing events with backoff behaviour
- obtain enough events from the database to fill messages as much as
possible
- minimize the amount of running goroutines
  - use pure timers for backoff
  - don't restart queue unless necessary
  - close the background task when backing off
- increase max edus in a transaction to match the spec
- cleanup timers more aggresively to reduce memory usage
- add jitter to backoff timers to reduce resource spikes
- add a bunch of tests (with real and fake databases) to ensure
everything is working
This commit is contained in:
devonh 2022-10-19 10:03:16 +00:00 committed by GitHub
parent 3aa92efaa3
commit 241d5c47df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1410 additions and 202 deletions

View File

@ -116,17 +116,14 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist() _ = federationDB.RemoveAllServersFromBlacklist()
} }
stats := &statistics.Statistics{ stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
DB: federationDB,
FailuresUntilBlacklist: cfg.FederationMaxRetries,
}
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation, cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, stats, cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{ &queue.SigningInfo{
KeyID: cfg.Matrix.KeyID, KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey, PrivateKey: cfg.Matrix.PrivateKey,
@ -183,5 +180,5 @@ func NewInternalAPI(
} }
time.AfterFunc(time.Minute, cleanExpiredEDUs) time.AfterFunc(time.Minute, cleanExpiredEDUs)
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing)
} }

View File

@ -35,7 +35,7 @@ import (
const ( const (
maxPDUsPerTransaction = 50 maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50 maxEDUsPerTransaction = 100
maxPDUsInMemory = 128 maxPDUsInMemory = 128
maxEDUsInMemory = 128 maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30 queueIdleTimeout = time.Second * 30
@ -64,7 +64,6 @@ type destinationQueue struct {
pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff
} }
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.
@ -75,6 +74,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return return
} }
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
@ -102,12 +102,12 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
oq.overflowed.Store(true) oq.overflowed.Store(true)
} }
oq.pendingMutex.Unlock() oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() if !oq.backingOff.Load() {
select { oq.wakeQueueAndNotify()
case oq.notify <- struct{}{}:
default:
} }
} else {
oq.overflowed.Store(true)
} }
} }
@ -147,12 +147,37 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.overflowed.Store(true) oq.overflowed.Store(true)
} }
oq.pendingMutex.Unlock() oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded() if !oq.backingOff.Load() {
select { oq.wakeQueueAndNotify()
case oq.notify <- struct{}{}:
default:
} }
} else {
oq.overflowed.Store(true)
}
}
// handleBackoffNotifier is registered as the backoff notification
// callback with Statistics. It will wakeup and notify the queue
// if the queue is currently backing off.
func (oq *destinationQueue) handleBackoffNotifier() {
// Only wake up the queue if it is backing off.
// Otherwise there is no pending work for the queue to handle
// so waking the queue would be a waste of resources.
if oq.backingOff.Load() {
oq.wakeQueueAndNotify()
}
}
// wakeQueueAndNotify ensures the destination queue is running and notifies it
// that there is pending work.
func (oq *destinationQueue) wakeQueueAndNotify() {
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
// Notify the queue that there are events ready to send.
select {
case oq.notify <- struct{}{}:
default:
} }
} }
@ -161,10 +186,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
// then we will interrupt the backoff, causing any federation // then we will interrupt the backoff, causing any federation
// requests to retry. // requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() { func (oq *destinationQueue) wakeQueueIfNeeded() {
// If we are backing off then interrupt the backoff. // Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) { if oq.backingOff.CompareAndSwap(true, false) {
oq.interruptBackoff <- true destinationQueueBackingOff.Dec()
} }
// If we aren't running then wake up the queue. // If we aren't running then wake up the queue.
if !oq.running.Load() { if !oq.running.Load() {
// Start the queue. // Start the queue.
@ -196,38 +222,54 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotEDUs[edu.receipt.String()] = struct{}{} gotEDUs[edu.receipt.String()] = struct{}{}
} }
overflowed := false
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that. // We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil {
if len(pdus) == maxPDUsInMemory {
overflowed = true
}
for receipt, pdu := range pdus { for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok { if _, ok := gotPDUs[receipt.String()]; ok {
continue continue
} }
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true retrieved = true
if len(oq.pendingPDUs) == maxPDUsInMemory {
break
}
} }
} else { } else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
} }
} }
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 { if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that. // We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil {
if len(edus) == maxEDUsInMemory {
overflowed = true
}
for receipt, edu := range edus { for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok { if _, ok := gotEDUs[receipt.String()]; ok {
continue continue
} }
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true retrieved = true
if len(oq.pendingEDUs) == maxEDUsInMemory {
break
}
} }
} else { } else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
} }
} }
// If we've retrieved all of the events from the database with room to spare // If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed. // in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { if !overflowed {
oq.overflowed.Store(false) oq.overflowed.Store(false)
} else {
} }
// If we've retrieved some events then notify the destination queue goroutine. // If we've retrieved some events then notify the destination queue goroutine.
if retrieved { if retrieved {
@ -238,6 +280,24 @@ func (oq *destinationQueue) getPendingFromDatabase() {
} }
} }
// checkNotificationsOnClose checks for any remaining notifications
// and starts a new backgroundSend goroutine if any exist.
func (oq *destinationQueue) checkNotificationsOnClose() {
// NOTE : If we are stopping the queue due to blacklist then it
// doesn't matter if we have been notified of new work since
// this queue instance will be deleted anyway.
if !oq.statistics.Blacklisted() {
select {
case <-oq.notify:
// We received a new notification in between the
// idle timeout firing and stopping the goroutine.
// Immediately restart the queue.
oq.wakeQueueAndNotify()
default:
}
}
}
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
func (oq *destinationQueue) backgroundSend() { func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then // Check if a worker is already running, and if it isn't, then
@ -245,10 +305,17 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CompareAndSwap(false, true) { if !oq.running.CompareAndSwap(false, true) {
return return
} }
// Register queue cleanup functions.
// NOTE : The ordering here is very intentional.
defer oq.checkNotificationsOnClose()
defer oq.running.Store(false)
destinationQueueRunning.Inc() destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec() defer destinationQueueRunning.Dec()
defer oq.queues.clearQueue(oq)
defer oq.running.Store(false) idleTimeout := time.NewTimer(queueIdleTimeout)
defer idleTimeout.Stop()
// Mark the queue as overflowed, so we will consult the database // Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send. // to see if there's anything new to send.
@ -261,59 +328,33 @@ func (oq *destinationQueue) backgroundSend() {
oq.getPendingFromDatabase() oq.getPendingFromDatabase()
} }
// Reset the queue idle timeout.
if !idleTimeout.Stop() {
select {
case <-idleTimeout.C:
default:
}
}
idleTimeout.Reset(queueIdleTimeout)
// If we have nothing to do then wait either for incoming events, or // If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout. // until we hit an idle timeout.
select { select {
case <-oq.notify: case <-oq.notify:
// There's work to do, either because getPendingFromDatabase // There's work to do, either because getPendingFromDatabase
// told us there is, or because a new event has come in via // told us there is, a new event has come in via sendEvent/sendEDU,
// sendEvent/sendEDU. // or we are backing off and it is time to retry.
case <-time.After(queueIdleTimeout): case <-idleTimeout.C:
// The worker is idle so stop the goroutine. It'll get // The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to // restarted automatically the next time we have an event to
// send. // send.
return return
case <-oq.process.Context().Done(): case <-oq.process.Context().Done():
// The parent process is shutting down, so stop. // The parent process is shutting down, so stop.
oq.statistics.ClearBackoff()
return return
} }
// If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly
// told to retry.
until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return
}
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
logrus.Debugf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
}
// Work out which PDUs/EDUs to include in the next transaction. // Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock() oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs) pduCount := len(oq.pendingPDUs)
@ -328,99 +369,52 @@ func (oq *destinationQueue) backgroundSend() {
toSendEDUs := oq.pendingEDUs[:eduCount] toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock() oq.pendingMutex.RUnlock()
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if pduCount == 0 && eduCount == 0 {
continue
}
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens. // Try sending the next transaction and see what happens.
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs) terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil { if terr != nil {
// We failed to send the transaction. Mark it as a failure. // We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure() _, blacklisted := oq.statistics.Failure()
if !blacklisted {
} else if transaction { // Register the backoff state and exit the goroutine.
// If we successfully sent the transaction then clear out // It'll get restarted automatically when the backoff
// the pending events and EDUs, and wipe our transaction ID. // completes.
oq.statistics.Success() oq.backingOff.Store(true)
oq.pendingMutex.Lock() destinationQueueBackingOff.Inc()
for i := range oq.pendingPDUs[:pc] { return
oq.pendingPDUs[i] = nil } else {
// Immediately trigger the blacklist logic.
oq.blacklistDestination()
return
} }
for i := range oq.pendingEDUs[:ec] { } else {
oq.pendingEDUs[i] = nil oq.handleTransactionSuccess(pduCount, eduCount)
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
} }
} }
} }
// nextTransaction creates a new transaction from the pending event // nextTransaction creates a new transaction from the pending event
// queue and sends it. Returns true if a transaction was sent or // queue and sends it.
// false otherwise. // Returns an error if the transaction wasn't sent.
func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU, pdus []*queuedPDU,
edus []*queuedEDU, edus []*queuedEDU,
) (bool, int, int, error) { ) error {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
// Create the transaction. // Create the transaction.
t := gomatrixserverlib.Transaction{ t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 {
return false, 0, 0, nil
}
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server. // Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel() defer cancel()
_, err := oq.client.SendTransaction(ctx, t) _, err := oq.client.SendTransaction(ctx, t)
switch err.(type) { switch errResponse := err.(type) {
case nil: case nil:
// Clean up the transaction in the database. // Clean up the transaction in the database.
if pduReceipts != nil { if pduReceipts != nil {
@ -439,16 +433,128 @@ func (oq *destinationQueue) nextTransaction(
oq.transactionIDMutex.Lock() oq.transactionIDMutex.Lock()
oq.transactionID = "" oq.transactionID = ""
oq.transactionIDMutex.Unlock() oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil return nil
case gomatrix.HTTPError: case gomatrix.HTTPError:
// Report that we failed to send the transaction and we // Report that we failed to send the transaction and we
// will retry again, subject to backoff. // will retry again, subject to backoff.
return false, 0, 0, err
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
code := errResponse.Code
logrus.Debug("Transaction failed with HTTP", code)
return err
default: default:
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"destination": oq.destination, "destination": oq.destination,
logrus.ErrorKey: err, logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID) }).Debugf("Failed to send transaction %q", t.TransactionID)
return false, 0, 0, err return err
}
}
// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus.
// It also returns the associated event receipts so they can be cleaned from the database in
// the case of a successful transaction.
func (oq *destinationQueue) createTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
// These should never be nil.
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
// These should never be nil.
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
return t, pduReceipts, eduReceipts
}
// blacklistDestination removes all pending PDUs and EDUs that have been cached
// and deletes this queue.
func (oq *destinationQueue) blacklistDestination() {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
// Delete this queue as no more messages will be sent to this
// destination until it is no longer blacklisted.
oq.statistics.AssignBackoffNotifier(nil)
oq.queues.clearQueue(oq)
}
// handleTransactionSuccess updates the cached event queues as well as the success and
// backoff information for this server.
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs[:pduCount] {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs[:eduCount] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pduCount:]
oq.pendingEDUs = oq.pendingEDUs[eduCount:]
oq.pendingMutex.Unlock()
if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 {
select {
case oq.notify <- struct{}{}:
default:
}
} }
} }

View File

@ -162,23 +162,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
if !ok || oq == nil { if !ok || oq == nil {
destinationQueueTotal.Inc() destinationQueueTotal.Inc()
oq = &destinationQueue{ oq = &destinationQueue{
queues: oqs, queues: oqs,
db: oqs.db, db: oqs.db,
process: oqs.process, process: oqs.process,
rsAPI: oqs.rsAPI, rsAPI: oqs.rsAPI,
origin: oqs.origin, origin: oqs.origin,
destination: destination, destination: destination,
client: oqs.client, client: oqs.client,
statistics: oqs.statistics.ForServer(destination), statistics: oqs.statistics.ForServer(destination),
notify: make(chan struct{}, 1), notify: make(chan struct{}, 1),
interruptBackoff: make(chan bool), signing: oqs.signing,
signing: oqs.signing,
} }
oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier)
oqs.queues[destination] = oq oqs.queues[destination] = oq
} }
return oq return oq
} }
// clearQueue removes the queue for the provided destination from the
// set of destination queues.
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock() oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock() defer oqs.queuesMutex.Unlock()
@ -332,7 +334,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled { if oqs.disabled {
return return
} }
oqs.statistics.ForServer(srv).RemoveBlacklist()
if queue := oqs.getQueue(srv); queue != nil { if queue := oqs.getQueue(srv); queue != nil {
queue.statistics.ClearBackoff()
queue.wakeQueueIfNeeded() queue.wakeQueueIfNeeded()
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,7 @@ package statistics
import ( import (
"math" "math"
"math/rand"
"sync" "sync"
"time" "time"
@ -20,12 +21,23 @@ type Statistics struct {
servers map[gomatrixserverlib.ServerName]*ServerStatistics servers map[gomatrixserverlib.ServerName]*ServerStatistics
mutex sync.RWMutex mutex sync.RWMutex
backoffTimers map[gomatrixserverlib.ServerName]*time.Timer
backoffMutex sync.RWMutex
// How many times should we tolerate consecutive failures before we // How many times should we tolerate consecutive failures before we
// just blacklist the host altogether? The backoff is exponential, // just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds. // so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32 FailuresUntilBlacklist uint32
} }
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
return Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
}
}
// ForServer returns server statistics for the given server name. If it // ForServer returns server statistics for the given server name. If it
// does not exist, it will create empty statistics and return those. // does not exist, it will create empty statistics and return those.
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{ server = &ServerStatistics{
statistics: s, statistics: s,
serverName: serverName, serverName: serverName,
interrupt: make(chan struct{}),
} }
s.servers[serverName] = server s.servers[serverName] = server
s.mutex.Unlock() s.mutex.Unlock()
@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
// many times we failed etc. It also manages the backoff time and black- // many times we failed etc. It also manages the backoff time and black-
// listing a remote host if it remains uncooperative. // listing a remote host if it remains uncooperative.
type ServerStatistics struct { type ServerStatistics struct {
statistics *Statistics // statistics *Statistics //
serverName gomatrixserverlib.ServerName // serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted blacklisted atomic.Bool // is the node blacklisted
backoffStarted atomic.Bool // is the backoff started backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine successCounter atomic.Uint32 // how many times have we succeeded?
successCounter atomic.Uint32 // how many times have we succeeded? backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
} }
const maxJitterMultiplier = 1.4
const minJitterMultiplier = 0.8
// duration returns how long the next backoff interval should be. // duration returns how long the next backoff interval should be.
func (s *ServerStatistics) duration(count uint32) time.Duration { func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count))) // Add some jitter to minimise the chance of having multiple backoffs
// ending at the same time.
jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
return duration
} }
// cancel will interrupt the currently active backoff. // cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() { func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false) s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{}) s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}: s.ClearBackoff()
default: }
}
// AssignBackoffNotifier configures the channel to send to when
// a backoff completes.
func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
s.backoffNotifier = notifier
} }
// Success updates the server statistics with a new successful // Success updates the server statistics with a new successful
@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() {
// we will unblacklist it. // we will unblacklist it.
func (s *ServerStatistics) Success() { func (s *ServerStatistics) Success() {
s.cancel() s.cancel()
s.successCounter.Inc()
s.backoffCount.Store(0) s.backoffCount.Store(0)
s.successCounter.Inc()
if s.statistics.DB != nil { if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() {
} }
// Failure marks a failure and starts backing off if needed. // Failure marks a failure and starts backing off if needed.
// The next call to BackoffIfRequired will do the right thing // It will return the time that the current failure
// after this. It will return the time that the current failure
// will result in backoff waiting until, and a bool signalling // will result in backoff waiting until, and a bool signalling
// whether we have blacklisted and therefore to give up. // whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) { func (s *ServerStatistics) Failure() (time.Time, bool) {
// Return immediately if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Time{}, true
}
// If we aren't already backing off, this call will start // If we aren't already backing off, this call will start
// a new backoff period. Increase the failure counter and // a new backoff period, increase the failure counter and
// start a goroutine which will wait out the backoff and // start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done. // unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) { if s.backoffStarted.CompareAndSwap(false, true) {
@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
} }
} }
s.ClearBackoff()
return time.Time{}, true return time.Time{}, true
} }
go func() { // We're starting a new back off so work out what the next interval
until, ok := s.backoffUntil.Load().(time.Time) // will be.
if ok && !until.IsZero() { count := s.backoffCount.Load()
select { until := time.Now().Add(s.duration(count))
case <-time.After(time.Until(until)): s.backoffUntil.Store(until)
case <-s.interrupt:
} s.statistics.backoffMutex.Lock()
s.backoffStarted.Store(false) defer s.statistics.backoffMutex.Unlock()
} s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
}()
} }
// Check if we have blacklisted this node. return s.backoffUntil.Load().(time.Time), false
if s.blacklisted.Load() { }
return time.Now(), true
}
// If we're already backing off and we haven't yet surpassed // ClearBackoff stops the backoff timer for this destination if it is running
// the deadline then return that. Repeated calls to Failure // and removes the timer from the backoffTimers map.
// within a single backoff interval will have no side effects. func (s *ServerStatistics) ClearBackoff() {
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) { // If the timer is still running then stop it so it's memory is cleaned up sooner.
return until, false s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
if timer, ok := s.statistics.backoffTimers[s.serverName]; ok {
timer.Stop()
} }
delete(s.statistics.backoffTimers, s.serverName)
// We're either backing off and have passed the deadline, or s.backoffStarted.Store(false)
// we aren't backing off, so work out what the next interval }
// will be.
count := s.backoffCount.Load() // backoffFinished will clear the previous backoff and notify the destination queue.
until := time.Now().Add(s.duration(count)) func (s *ServerStatistics) backoffFinished() {
s.backoffUntil.Store(until) s.ClearBackoff()
return until, false
// Notify the destinationQueue if one is currently running.
s.notifierMutex.Lock()
defer s.notifierMutex.Unlock()
if s.backoffNotifier != nil {
s.backoffNotifier()
}
} }
// BackoffInfo returns information about the current or previous backoff. // BackoffInfo returns information about the current or previous backoff.
@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load() return s.blacklisted.Load()
} }
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
s.cancel()
s.backoffCount.Store(0)
}
// SuccessCount returns the number of successful requests. This is // SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs. // usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 { func (s *ServerStatistics) SuccessCount() uint32 {

View File

@ -7,9 +7,7 @@ import (
) )
func TestBackoff(t *testing.T) { func TestBackoff(t *testing.T) {
stats := Statistics{ stats := NewStatistics(nil, 7)
FailuresUntilBlacklist: 7,
}
server := ServerStatistics{ server := ServerStatistics{
statistics: &stats, statistics: &stats,
serverName: "test.com", serverName: "test.com",
@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) {
// Get the duration. // Get the duration.
_, blacklist := server.BackoffInfo() _, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second) duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that // Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result. // there's a backoff in progress and return the same result.
@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) {
// Check if the duration is what we expect. // Check if the duration is what we expect.
t.Logf("Backoff %d is for %s", i, duration) t.Logf("Backoff %d is for %s", i, duration)
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { roundingAllowance := 0.01
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance)
maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance)
var inJitterRange bool
if duration >= minDuration && duration <= maxDuration {
inJitterRange = true
} else {
inJitterRange = false
}
if !blacklist && !inJitterRange {
t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration)
} }
} }
} }

View File

@ -52,6 +52,10 @@ type Receipt struct {
nid int64 nid int64
} }
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) String() string { func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid) return fmt.Sprintf("%d", r.nid)
} }

2
go.mod
View File

@ -50,6 +50,7 @@ require (
golang.org/x/term v0.0.0-20220919170432-7a66f970e087 golang.org/x/term v0.0.0-20220919170432-7a66f970e087
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0
nhooyr.io/websocket v1.8.7 nhooyr.io/websocket v1.8.7
) )
@ -127,7 +128,6 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.4.0 // indirect
) )
go 1.18 go 1.18