Federation API workers for /send to reduce memory usage (#1897)

* Try to process rooms concurrently in FS /send

* Clean up

* Use request context so that dead things don't linger for so long

* Remove mutex

* Free up pdus slice so only references remaining are in channel

* Revert "Remove mutex"

This reverts commit 8558075e8c9bab3c1d8b2252b4ab40c7eaf774e8.

* Process EDUs in parallel

* Try refactoring /send concurrency

* Fix waitgroup

* Release on waitgroup

* Respond to transaction

* Reduce CPU usage, fix unit tests

* Tweaks

* Move into one file
This commit is contained in:
Neil Alexander 2021-07-02 12:33:27 +01:00 committed by GitHub
parent 192a7a7923
commit 57320897cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -16,7 +16,6 @@ package routing
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -24,7 +23,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@ -36,6 +34,7 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.uber.org/atomic"
) )
const ( const (
@ -90,6 +89,67 @@ func init() {
) )
} }
type sendFIFOQueue struct {
tasks []*inputTask
count int
mutex sync.Mutex
notifs chan struct{}
}
func newSendFIFOQueue() *sendFIFOQueue {
q := &sendFIFOQueue{
notifs: make(chan struct{}, 1),
}
return q
}
func (q *sendFIFOQueue) push(frame *inputTask) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.tasks = append(q.tasks, frame)
q.count++
select {
case q.notifs <- struct{}{}:
default:
}
}
// pop returns the first item of the queue, if there is one.
// The second return value will indicate if a task was returned.
func (q *sendFIFOQueue) pop() (*inputTask, bool) {
q.mutex.Lock()
defer q.mutex.Unlock()
if q.count == 0 {
return nil, false
}
frame := q.tasks[0]
q.tasks[0] = nil
q.tasks = q.tasks[1:]
q.count--
if q.count == 0 {
// Force a GC of the underlying array, since it might have
// grown significantly if the queue was hammered for some reason
q.tasks = nil
}
return frame, true
}
type inputTask struct {
ctx context.Context
t *txnReq
event *gomatrixserverlib.Event
wg *sync.WaitGroup
err error // written back by worker, only safe to read when all tasks are done
duration time.Duration // written back by worker, only safe to read when all tasks are done
}
type inputWorker struct {
running atomic.Bool
input *sendFIFOQueue
}
var inputWorkers sync.Map // room ID -> *inputWorker
// Send implements /_matrix/federation/v1/send/{txnID} // Send implements /_matrix/federation/v1/send/{txnID}
func Send( func Send(
httpReq *http.Request, httpReq *http.Request,
@ -193,8 +253,12 @@ type txnFederationClient interface {
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
results := make(map[string]gomatrixserverlib.PDUResult) results := make(map[string]gomatrixserverlib.PDUResult)
//var resultsMutex sync.Mutex
var wg sync.WaitGroup
var tasks []*inputTask
wg.Add(1) // for processEDUs
pdus := []*gomatrixserverlib.HeaderedEvent{}
for _, pdu := range t.PDUs { for _, pdu := range t.PDUs {
pduCountTotal.WithLabelValues("total").Inc() pduCountTotal.WithLabelValues("total").Inc()
var header struct { var header struct {
@ -245,83 +309,97 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
} }
continue continue
} }
pdus = append(pdus, event.Headered(verRes.RoomVersion)) v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{
input: newSendFIFOQueue(),
})
worker := v.(*inputWorker)
if !worker.running.Load() {
go worker.run()
}
wg.Add(1)
task := &inputTask{
ctx: ctx,
t: t,
event: event,
wg: &wg,
}
tasks = append(tasks, task)
worker.input.push(task)
} }
// Process the events. go func() {
for _, e := range pdus { defer wg.Done()
evStart := time.Now() t.processEDUs(ctx)
if err := t.processEvent(ctx, e.Unwrap()); err != nil { }()
// If the error is due to the event itself being bad then we skip
// it and move onto the next event. We report an error so that the wg.Wait()
// sender knows that we have skipped processing it.
// for _, task := range tasks {
// However if the event is due to a temporary failure in our server if task.err != nil {
// such as a database being unavailable then we should bail, and results[task.event.EventID()] = gomatrixserverlib.PDUResult{
// hope that the sender will retry when we are feeling better. Error: task.err.Error(),
//
// It is uncertain what we should do if an event fails because
// we failed to fetch more information from the sending server.
// For example if a request to /state fails.
// If we skip the event then we risk missing the event until we
// receive another event referencing it.
// If we bail and stop processing then we risk wedging incoming
// transactions from that server forever.
if isProcessingErrorFatal(err) {
sentry.CaptureException(err)
// Any other error should be the result of a temporary error in
// our server so we should bail processing the transaction entirely.
util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err)
jsonErr := util.ErrorResponse(err)
processEventSummary.WithLabelValues(t.work, MetricsOutcomeFatal).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
return nil, &jsonErr
} else {
// Auth errors mean the event is 'rejected' which have to be silent to appease sytest
errMsg := ""
outcome := MetricsOutcomeRejected
_, rejected := err.(*gomatrixserverlib.NotAllowed)
if !rejected {
errMsg = err.Error()
outcome = MetricsOutcomeFail
}
util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn(
"Failed to process incoming federation event, skipping",
)
processEventSummary.WithLabelValues(t.work, outcome).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
results[e.EventID()] = gomatrixserverlib.PDUResult{
Error: errMsg,
}
} }
} else { } else {
results[e.EventID()] = gomatrixserverlib.PDUResult{} results[task.event.EventID()] = gomatrixserverlib.PDUResult{}
pduCountTotal.WithLabelValues("success").Inc()
processEventSummary.WithLabelValues(t.work, MetricsOutcomeOK).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
} }
} }
t.processEDUs(ctx)
if c := len(results); c > 0 { if c := len(results); c > 0 {
util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
} }
return &gomatrixserverlib.RespSend{PDUs: results}, nil return &gomatrixserverlib.RespSend{PDUs: results}, nil
} }
// isProcessingErrorFatal returns true if the error is really bad and func (t *inputWorker) run() {
// we should stop processing the transaction, and returns false if it if !t.running.CAS(false, true) {
// is just some less serious error about a specific event. return
func isProcessingErrorFatal(err error) bool { }
switch err { defer t.running.Store(false)
case sql.ErrConnDone: for {
case sql.ErrTxDone: task, ok := t.input.pop()
return true if !ok {
return
}
if task == nil {
continue
}
func() {
defer task.wg.Done()
select {
case <-task.ctx.Done():
task.err = context.DeadlineExceeded
return
default:
evStart := time.Now()
task.err = task.t.processEvent(task.ctx, task.event)
task.duration = time.Since(evStart)
if err := task.err; err != nil {
switch err.(type) {
case *gomatrixserverlib.NotAllowed:
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn(
"Failed to process incoming federation event, skipping",
)
task.err = nil // make "rejected" failures silent
default:
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn(
"Failed to process incoming federation event, skipping",
)
}
} else {
pduCountTotal.WithLabelValues("success").Inc()
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
}
}
}()
} }
return false
} }
type roomNotFoundError struct { type roomNotFoundError struct {
@ -633,11 +711,6 @@ func (t *txnReq) processEventWithMissingState(
processEventWithMissingStateMutexes.Lock(e.RoomID()) processEventWithMissingStateMutexes.Lock(e.RoomID())
defer processEventWithMissingStateMutexes.Unlock(e.RoomID()) defer processEventWithMissingStateMutexes.Unlock(e.RoomID())
// Do this with a fresh context, so that we keep working even if the
// original request times out. With any luck, by the time the remote
// side retries, we'll have fetched the missing state.
gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
// We are missing the previous events for this events. // We are missing the previous events for this events.
// This means that there is a gap in our view of the history of the // This means that there is a gap in our view of the history of the
// room. There two ways that we can handle such a gap: // room. There two ways that we can handle such a gap:
@ -658,7 +731,7 @@ func (t *txnReq) processEventWithMissingState(
// - fill in the gap completely then process event `e` returning no backwards extremity // - fill in the gap completely then process event `e` returning no backwards extremity
// - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to terminate the transaction err=not nil
// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
newEvents, err := t.getMissingEvents(gmectx, e, roomVersion) newEvents, err := t.getMissingEvents(ctx, e, roomVersion)
if err != nil { if err != nil {
return err return err
} }
@ -685,7 +758,7 @@ func (t *txnReq) processEventWithMissingState(
// Look up what the state is after the backward extremity. This will either // Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will // come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not. // come from a remote server via /state_ids if not.
prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID) prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
if lerr != nil { if lerr != nil {
util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
return lerr return lerr
@ -729,7 +802,7 @@ func (t *txnReq) processEventWithMissingState(
} }
// There's more than one previous state - run them all through state res // There's more than one previous state - run them all through state res
t.roomsMu.Lock(e.RoomID()) t.roomsMu.Lock(e.RoomID())
resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity) resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity)
t.roomsMu.Unlock(e.RoomID()) t.roomsMu.Unlock(e.RoomID())
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())