From 57320897cba655046d47d2cddc7a1381a04d5c66 Mon Sep 17 00:00:00 2001
From: Neil Alexander
Date: Fri, 2 Jul 2021 12:33:27 +0100
Subject: [PATCH] Federation API workers for /send to reduce memory usage
 (#1897)

* Try to process rooms concurrently in FS /send
* Clean up
* Use request context so that dead things don't linger for so long
* Remove mutex
* Free up pdus slice so only references remaining are in channel
* Revert "Remove mutex"
  This reverts commit 8558075e8c9bab3c1d8b2252b4ab40c7eaf774e8.
* Process EDUs in parallel
* Try refactoring /send concurrency
* Fix waitgroup
* Release on waitgroup
* Respond to transaction
* Reduce CPU usage, fix unit tests
* Tweaks
* Move into one file
---
 federationapi/routing/send.go | 221 ++++++++++++++++++++++------------
 1 file changed, 147 insertions(+), 74 deletions(-)

diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go
index a514127c..ae9a63fc 100644
--- a/federationapi/routing/send.go
+++ b/federationapi/routing/send.go
@@ -16,7 +16,6 @@ package routing
 import (
 	"context"
-	"database/sql"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -24,7 +23,6 @@ import (
 	"sync"
 	"time"
 
-	"github.com/getsentry/sentry-go"
 	"github.com/matrix-org/dendrite/clientapi/jsonerror"
 	eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
 	federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@@ -36,6 +34,7 @@ import (
 	"github.com/matrix-org/util"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/sirupsen/logrus"
+	"go.uber.org/atomic"
 )
 
 const (
@@ -90,6 +89,67 @@ func init() {
 	)
 }
 
+type sendFIFOQueue struct {
+	tasks  []*inputTask
+	count  int
+	mutex  sync.Mutex
+	notifs chan struct{}
+}
+
+func newSendFIFOQueue() *sendFIFOQueue {
+	q := &sendFIFOQueue{
+		notifs: make(chan struct{}, 1),
+	}
+	return q
+}
+
+func (q *sendFIFOQueue) push(frame *inputTask) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	q.tasks = append(q.tasks, frame)
+	q.count++
+	select {
+	case q.notifs <- struct{}{}:
+	default:
+	}
+}
+
+// pop returns the first item of the queue, if there is one.
+// The second return value will indicate if a task was returned.
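+// If the queue is empty, pop returns nil and false; when the last task is
+// popped, the backing array is released so it can be garbage collected.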
+func (q *sendFIFOQueue) pop() (*inputTask, bool) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+	if q.count == 0 {
+		return nil, false
+	}
+	frame := q.tasks[0]
+	q.tasks[0] = nil
+	q.tasks = q.tasks[1:]
+	q.count--
+	if q.count == 0 {
+		// Force a GC of the underlying array, since it might have
+		// grown significantly if the queue was hammered for some reason
+		q.tasks = nil
+	}
+	return frame, true
+}
+
+type inputTask struct {
+	ctx      context.Context
+	t        *txnReq
+	event    *gomatrixserverlib.Event
+	wg       *sync.WaitGroup
+	err      error         // written back by worker, only safe to read when all tasks are done
+	duration time.Duration // written back by worker, only safe to read when all tasks are done
+}
+
+type inputWorker struct {
+	running atomic.Bool
+	input   *sendFIFOQueue
+}
+
+var inputWorkers sync.Map // room ID -> *inputWorker
+
 // Send implements /_matrix/federation/v1/send/{txnID}
 func Send(
 	httpReq *http.Request,
@@ -193,8 +253,12 @@ type txnFederationClient interface {
 
 func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
 	results := make(map[string]gomatrixserverlib.PDUResult)
+	//var resultsMutex sync.Mutex
+
+	var wg sync.WaitGroup
+	var tasks []*inputTask
+	wg.Add(1) // for processEDUs
 
-	pdus := []*gomatrixserverlib.HeaderedEvent{}
 	for _, pdu := range t.PDUs {
 		pduCountTotal.WithLabelValues("total").Inc()
 		var header struct {
@@ -245,83 +309,97 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
 			}
 			continue
 		}
-		pdus = append(pdus, event.Headered(verRes.RoomVersion))
+		v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{
+			input: newSendFIFOQueue(),
+		})
+		worker := v.(*inputWorker)
+		if !worker.running.Load() {
+			go worker.run()
+		}
+		wg.Add(1)
+		task := &inputTask{
+			ctx:   ctx,
+			t:     t,
+			event: event,
+			wg:    &wg,
+		}
+		tasks = append(tasks, task)
+		worker.input.push(task)
 	}
 
-	// Process the events.
-	for _, e := range pdus {
-		evStart := time.Now()
-		if err := t.processEvent(ctx, e.Unwrap()); err != nil {
-			// If the error is due to the event itself being bad then we skip
-			// it and move onto the next event. We report an error so that the
-			// sender knows that we have skipped processing it.
-			//
-			// However if the event is due to a temporary failure in our server
-			// such as a database being unavailable then we should bail, and
-			// hope that the sender will retry when we are feeling better.
-			//
-			// It is uncertain what we should do if an event fails because
-			// we failed to fetch more information from the sending server.
-			// For example if a request to /state fails.
-			// If we skip the event then we risk missing the event until we
-			// receive another event referencing it.
-			// If we bail and stop processing then we risk wedging incoming
-			// transactions from that server forever.
-			if isProcessingErrorFatal(err) {
-				sentry.CaptureException(err)
-				// Any other error should be the result of a temporary error in
-				// our server so we should bail processing the transaction entirely.
-				util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err)
-				jsonErr := util.ErrorResponse(err)
-				processEventSummary.WithLabelValues(t.work, MetricsOutcomeFatal).Observe(
-					float64(time.Since(evStart).Nanoseconds()) / 1000.,
-				)
-				return nil, &jsonErr
-			} else {
-				// Auth errors mean the event is 'rejected' which have to be silent to appease sytest
-				errMsg := ""
-				outcome := MetricsOutcomeRejected
-				_, rejected := err.(*gomatrixserverlib.NotAllowed)
-				if !rejected {
-					errMsg = err.Error()
-					outcome = MetricsOutcomeFail
-				}
-				util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn(
-					"Failed to process incoming federation event, skipping",
-				)
-				processEventSummary.WithLabelValues(t.work, outcome).Observe(
-					float64(time.Since(evStart).Nanoseconds()) / 1000.,
-				)
-				results[e.EventID()] = gomatrixserverlib.PDUResult{
-					Error: errMsg,
-				}
+	go func() {
+		defer wg.Done()
+		t.processEDUs(ctx)
+	}()
+
+	wg.Wait()
+
+	for _, task := range tasks {
+		if task.err != nil {
+			results[task.event.EventID()] = gomatrixserverlib.PDUResult{
+				Error: task.err.Error(),
 			}
 		} else {
-			results[e.EventID()] = gomatrixserverlib.PDUResult{}
-			pduCountTotal.WithLabelValues("success").Inc()
-			processEventSummary.WithLabelValues(t.work, MetricsOutcomeOK).Observe(
-				float64(time.Since(evStart).Nanoseconds()) / 1000.,
-			)
+			results[task.event.EventID()] = gomatrixserverlib.PDUResult{}
 		}
 	}
 
-	t.processEDUs(ctx)
 	if c := len(results); c > 0 {
 		util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
 	}
 	return &gomatrixserverlib.RespSend{PDUs: results}, nil
 }
 
-// isProcessingErrorFatal returns true if the error is really bad and
-// we should stop processing the transaction, and returns false if it
-// is just some less serious error about a specific event.
-func isProcessingErrorFatal(err error) bool {
-	switch err {
-	case sql.ErrConnDone:
-	case sql.ErrTxDone:
-		return true
+func (t *inputWorker) run() {
+	if !t.running.CAS(false, true) {
+		return
+	}
+	defer t.running.Store(false)
+	for {
+		task, ok := t.input.pop()
+		if !ok {
+			return
+		}
+		if task == nil {
+			continue
+		}
+		func() {
+			defer task.wg.Done()
+			select {
+			case <-task.ctx.Done():
+				task.err = context.DeadlineExceeded
+				return
+			default:
+				evStart := time.Now()
+				task.err = task.t.processEvent(task.ctx, task.event)
+				task.duration = time.Since(evStart)
+				if err := task.err; err != nil {
+					switch err.(type) {
+					case *gomatrixserverlib.NotAllowed:
+						processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe(
+							float64(time.Since(evStart).Nanoseconds()) / 1000.,
+						)
+						util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn(
+							"Failed to process incoming federation event, skipping",
+						)
+						task.err = nil // make "rejected" failures silent
+					default:
+						processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe(
+							float64(time.Since(evStart).Nanoseconds()) / 1000.,
+						)
+						util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn(
+							"Failed to process incoming federation event, skipping",
+						)
+					}
+				} else {
+					pduCountTotal.WithLabelValues("success").Inc()
+					processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe(
+						float64(time.Since(evStart).Nanoseconds()) / 1000.,
+					)
+				}
+			}
+		}()
 	}
-	return false
 }
 
 type roomNotFoundError struct {
@@ -633,11 +711,6 @@ func (t *txnReq) processEventWithMissingState(
 	processEventWithMissingStateMutexes.Lock(e.RoomID())
 	defer processEventWithMissingStateMutexes.Unlock(e.RoomID())
 
-	// Do this with a fresh context, so that we keep working even if the
-	// original request times out. With any luck, by the time the remote
-	// side retries, we'll have fetched the missing state.
-	gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
-	defer cancel()
 	// We are missing the previous events for this event.
 	// This means that there is a gap in our view of the history of the
 	// room. There are two ways that we can handle such a gap:
@@ -658,7 +731,7 @@ func (t *txnReq) processEventWithMissingState(
 	// - fill in the gap completely then process event `e` returning no backwards extremity
 	// - fail to fill in the gap and tell us to terminate the transaction err=not nil
 	// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
-	newEvents, err := t.getMissingEvents(gmectx, e, roomVersion)
+	newEvents, err := t.getMissingEvents(ctx, e, roomVersion)
 	if err != nil {
 		return err
 	}
@@ -685,7 +758,7 @@ func (t *txnReq) processEventWithMissingState(
 	// Look up what the state is after the backward extremity. This will either
 	// come from the roomserver, if we know all the required events, or it will
 	// come from a remote server via /state_ids if not.
-	prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
+	prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
 	if lerr != nil {
 		util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
 		return lerr
 	}
@@ -729,7 +802,7 @@ func (t *txnReq) processEventWithMissingState(
 	}
 	// There's more than one previous state - run them all through state res
 	t.roomsMu.Lock(e.RoomID())
-	resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity)
+	resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity)
 	t.roomsMu.Unlock(e.RoomID())
 	if err != nil {
 		util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
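The core pattern in this patch can be sketched in isolation: a per-room worker registered in a sync.Map, fed tasks through a queue, with processTransaction blocking on a sync.WaitGroup before it assembles the per-PDU results. The sketch below is illustrative rather than Dendrite code: the names task, worker, workers, workerForRoom and processEvent are hypothetical, a buffered channel stands in for sendFIFOQueue, and each room keeps a long-lived goroutine instead of the patch's atomic running flag.

package main

import (
	"fmt"
	"sync"
)

// task mirrors inputTask: the worker writes err, and err is only safe to
// read after the WaitGroup has been released for every queued task.
type task struct {
	eventID string
	wg      *sync.WaitGroup
	err     error
}

// worker serialises the events of a single room.
type worker struct {
	queue chan *task // stands in for sendFIFOQueue
}

var workers sync.Map // room ID -> *worker, like inputWorkers above

// workerForRoom returns the worker for a room, creating it and its goroutine
// on first use. Events for one room are therefore handled strictly in order,
// while different rooms are processed in parallel.
func workerForRoom(roomID string) *worker {
	w := &worker{queue: make(chan *task, 64)}
	if v, loaded := workers.LoadOrStore(roomID, w); loaded {
		return v.(*worker) // another goroutine registered this room first
	}
	go func() {
		for t := range w.queue {
			t.err = processEvent(t) // stands in for txnReq.processEvent
			t.wg.Done()
		}
	}()
	return w
}

// processEvent is a placeholder for the real federation event processing.
func processEvent(t *task) error { return nil }

func main() {
	var wg sync.WaitGroup
	pdus := map[string][]*task{
		"!a:example.org": {{eventID: "$1"}, {eventID: "$2"}},
		"!b:example.org": {{eventID: "$3"}},
	}
	for roomID, tasks := range pdus {
		w := workerForRoom(roomID)
		for _, t := range tasks {
			t.wg = &wg
			wg.Add(1)
			w.queue <- t
		}
	}
	// Mirrors processTransaction: wait for every room's worker to finish
	// before building the per-PDU results for the transaction response.
	wg.Wait()
	for _, tasks := range pdus {
		for _, t := range tasks {
			fmt.Println(t.eventID, t.err)
		}
	}
}

The patch itself avoids holding an idle goroutine per room: run() returns once pop() reports an empty queue, and the next transaction that pushes work for that room starts a fresh goroutine after checking the running flag.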