diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index d437d776..deb88ea8 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" ) @@ -75,6 +76,11 @@ func (r *Inputer) processRoomEvent( default: } + span, ctx := opentracing.StartSpanFromContext(ctx, "processRoomEvent") + span.SetTag("room_id", input.Event.RoomID()) + span.SetTag("event_id", input.Event.EventID()) + defer span.Finish() + // Measure how long it takes to process this event. started := time.Now() defer func() { @@ -411,6 +417,9 @@ func (r *Inputer) fetchAuthEvents( known map[string]*types.Event, servers []gomatrixserverlib.ServerName, ) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "fetchAuthEvents") + defer span.Finish() + unknown := map[string]struct{}{} authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { @@ -526,6 +535,9 @@ func (r *Inputer) calculateAndSetState( event *gomatrixserverlib.Event, isRejected bool, ) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "calculateAndSetState") + defer span.Finish() + var succeeded bool updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) if err != nil { diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index f772299a..9738ed4e 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" ) @@ -56,6 +57,9 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "updateLatestEvents") + defer span.Finish() + var succeeded bool updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) if err != nil { @@ -200,6 +204,9 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { } func (u *latestEventsUpdater) latestState() error { + span, ctx := opentracing.StartSpanFromContext(u.ctx, "processEventWithMissingState") + defer span.Finish() + var err error roomState := state.NewStateResolution(u.updater, u.roomInfo) @@ -246,7 +253,7 @@ func (u *latestEventsUpdater) latestState() error { // of the state after the events. The snapshot state will be resolved // using the correct state resolution algorithm for the room. u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( - u.ctx, latestStateAtEvents, + ctx, latestStateAtEvents, ) if err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) @@ -258,7 +265,7 @@ func (u *latestEventsUpdater) latestState() error { // another list of added ones. Replacing a value for a state-key tuple // will result one removed (the old event) and one added (the new event). u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots( - u.ctx, u.oldStateNID, u.newStateNID, + ctx, u.oldStateNID, u.newStateNID, ) if err != nil { return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) @@ -278,7 +285,7 @@ func (u *latestEventsUpdater) latestState() error { // Also work out the state before the event removes and the event // adds. u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( - u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, + ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) if err != nil { return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) @@ -294,6 +301,9 @@ func (u *latestEventsUpdater) calculateLatest( newEvent *gomatrixserverlib.Event, newStateAndRef types.StateAtEventAndReference, ) (bool, error) { + span, _ := opentracing.StartSpanFromContext(u.ctx, "calculateLatest") + defer span.Finish() + // First of all, get a list of all of the events in our current // set of forward extremities. existingRefs := make(map[string]*types.StateAtEventAndReference) diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 3953586b..3ce8791a 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/opentracing/opentracing-go" ) // updateMembership updates the current membership and the invites for each @@ -34,6 +35,9 @@ func (r *Inputer) updateMemberships( updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "updateMemberships") + defer span.Finish() + changes := membershipChanges(removed, added) var eventNIDs []types.EventNID for _, change := range changes { diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 9c70076c..edc153b7 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -15,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" ) @@ -59,6 +60,9 @@ type missingStateReq struct { func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) (*parsedRespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "processEventWithMissingState") + defer span.Finish() + // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -235,6 +239,9 @@ func (t *missingStateReq) processEventWithMissingState( } func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupResolvedStateBeforeEvent") + defer span.Finish() + type respState struct { // A snapshot is considered trustworthy if it came from our own roomserver. // That's because the state will have been through state resolution once @@ -310,6 +317,9 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEvent") + defer span.Finish() + // try doing all this locally before we resort to querying federation respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) if respState != nil { @@ -361,6 +371,9 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixs } func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateAfterEventLocally") + defer span.Finish() + var res parsedRespState roomInfo, err := t.db.RoomInfo(ctx, roomID) if err != nil { @@ -435,12 +448,17 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room // the server supports. func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( *parsedRespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupStateBeforeEvent") + defer span.Finish() // Attempt to fetch the missing state using /state_ids and /events return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "resolveStatesAndCheck") + defer span.Finish() + var authEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event for _, state := range states { @@ -484,6 +502,9 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "getMissingEvents") + defer span.Finish() + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) if err != nil { @@ -608,6 +629,9 @@ func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserve func (t *missingStateReq) lookupMissingStateViaState( ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (respState *parsedRespState, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") + defer span.Finish() + state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err @@ -637,6 +661,9 @@ func (t *missingStateReq) lookupMissingStateViaState( func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *parsedRespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaStateIDs") + defer span.Finish() + util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) @@ -799,6 +826,9 @@ func (t *missingStateReq) createRespStateFromStateIDs( } func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "lookupEvent") + defer span.Finish() + if localFirst { // fetch from the roomserver events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 95abdcb3..6c4e4b86 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -20,9 +20,11 @@ import ( "context" "fmt" "sort" + "sync" "time" "github.com/matrix-org/util" + "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/dendrite/roomserver/types" @@ -62,6 +64,9 @@ func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) Sta func (v *StateResolution) LoadStateAtSnapshot( ctx context.Context, stateNID types.StateSnapshotNID, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot") + defer span.Finish() + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { return nil, err @@ -100,6 +105,9 @@ func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtEvent( ctx context.Context, eventID string, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") + defer span.Finish() + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) @@ -122,6 +130,9 @@ func (v *StateResolution) LoadStateAtEvent( func (v *StateResolution) LoadCombinedStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents") + defer span.Finish() + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) for i, state := range prevStates { stateNIDs[i] = state.BeforeStateSnapshotNID @@ -194,6 +205,9 @@ func (v *StateResolution) LoadCombinedStateAfterEvents( func (v *StateResolution) DifferenceBetweeenStateSnapshots( ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, ) (removed, added []types.StateEntry, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots") + defer span.Finish() + if oldStateNID == newStateNID { // If the snapshot NIDs are the same then nothing has changed return nil, nil, nil @@ -255,6 +269,9 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples( stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples") + defer span.Finish() + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { return nil, err @@ -269,6 +286,9 @@ func (v *StateResolution) stringTuplesToNumericTuples( ctx context.Context, stringTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateKeyTuple, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples") + defer span.Finish() + eventTypes := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples)) for i := range stringTuples { @@ -311,6 +331,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples( stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples") + defer span.Finish() + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { return nil, err @@ -359,6 +382,9 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples( prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples") + defer span.Finish() + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { return nil, err @@ -371,6 +397,9 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples( prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples") + defer span.Finish() + if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] @@ -543,6 +572,9 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent( event *gomatrixserverlib.Event, isRejected bool, ) (types.StateSnapshotNID, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") + defer span.Finish() + // Load the state at the prev events. prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) if err != nil { @@ -559,6 +591,9 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents") + defer span.Finish() + metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} if len(prevStates) == 0 { @@ -631,6 +666,9 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents") + defer span.Finish() + state, algorithm, conflictLength, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) metrics.algorithm = algorithm @@ -649,6 +687,9 @@ func (v *StateResolution) calculateStateAfterManyEvents( ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, prevStates []types.StateAtEvent, ) (state []types.StateEntry, algorithm string, conflictLength int, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents") + defer span.Finish() + var combined []types.StateEntry // Conflict resolution. // First stage: load the state after each of the prev events. @@ -701,6 +742,9 @@ func (v *StateResolution) resolveConflicts( ctx context.Context, version gomatrixserverlib.RoomVersion, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts") + defer span.Finish() + stateResAlgo, err := version.StateResAlgorithm() if err != nil { return nil, err @@ -725,6 +769,8 @@ func (v *StateResolution) resolveConflictsV1( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1") + defer span.Finish() // Load the conflicted events conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) @@ -788,6 +834,9 @@ func (v *StateResolution) resolveConflictsV2( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2") + defer span.Finish() + estimate := len(conflicted) + len(notConflicted) eventIDMap := make(map[string]types.StateEntry, estimate) @@ -815,31 +864,47 @@ func (v *StateResolution) resolveConflictsV2( authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3) authDifference := make([]*gomatrixserverlib.Event, 0, estimate) + knownAuthEvents := make(map[string]types.Event, estimate*3) // For each conflicted event, let's try and get the needed auth events. - for _, conflictedEvent := range conflictedEvents { - // Work out which auth events we need to load. - key := conflictedEvent.EventID() + if err = func() error { + span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents") + defer span.Finish() - // Store the newly found auth events in the auth set for this event. - var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) - if err != nil { - return nil, err - } - for k, v := range authEventMap { - eventIDMap[k] = v + loader := authEventLoader{ + v: v, + lookupFromDB: make([]string, 0, len(conflictedEvents)*3), + lookupFromMem: make([]string, 0, len(conflictedEvents)*3), + lookedUpEvents: make([]types.Event, 0, len(conflictedEvents)*3), + eventMap: map[string]types.Event{}, } + for _, conflictedEvent := range conflictedEvents { + // Work out which auth events we need to load. + key := conflictedEvent.EventID() - // Only add auth events into the authEvents slice once, otherwise the - // check for the auth difference can become expensive and produce - // duplicate entries, which just waste memory and CPU time. - for _, event := range authSets[key] { - if _, ok := gotAuthEvents[event.EventID()]; !ok { - authEvents = append(authEvents, event) - gotAuthEvents[event.EventID()] = struct{}{} + // Store the newly found auth events in the auth set for this event. + var authEventMap map[string]types.StateEntry + authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents) + if err != nil { + return err + } + for k, v := range authEventMap { + eventIDMap[k] = v + } + + // Only add auth events into the authEvents slice once, otherwise the + // check for the auth difference can become expensive and produce + // duplicate entries, which just waste memory and CPU time. + for _, event := range authSets[key] { + if _, ok := gotAuthEvents[event.EventID()]; !ok { + authEvents = append(authEvents, event) + gotAuthEvents[event.EventID()] = struct{}{} + } } } + return nil + }(); err != nil { + return nil, err } // Kill the reference to this so that the GC may pick it up, since we no @@ -870,19 +935,29 @@ func (v *StateResolution) resolveConflictsV2( // Look through all of the auth events that we've been given and work out if // there are any events which don't appear in all of the auth sets. If they // don't then we add them to the auth difference. - for _, event := range authEvents { - if !isInAllAuthLists(event) { - authDifference = append(authDifference, event) + func() { + span, _ := opentracing.StartSpanFromContext(ctx, "isInAllAuthLists") + defer span.Finish() + + for _, event := range authEvents { + if !isInAllAuthLists(event) { + authDifference = append(authDifference, event) + } } - } + }() // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2( - conflictedEvents, - nonConflictedEvents, - authEvents, - authDifference, - ) + resolvedEvents := func() []*gomatrixserverlib.Event { + span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2") + defer span.Finish() + + return gomatrixserverlib.ResolveStateConflictsV2( + conflictedEvents, + nonConflictedEvents, + authEvents, + authDifference, + ) + }() // Map from the full events back to numeric state entries. for _, resolvedEvent := range resolvedEvents { @@ -947,6 +1022,9 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E func (v *StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents") + defer span.Finish() + result := make([]*gomatrixserverlib.Event, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries)) eventNIDs := make([]types.EventNID, 0, len(entries)) @@ -975,43 +1053,86 @@ func (v *StateResolution) loadStateEvents( return result, eventIDMap, nil } +type authEventLoader struct { + sync.Mutex + v *StateResolution + lookupFromDB []string // scratch space + lookupFromMem []string // scratch space + lookedUpEvents []types.Event // scratch space + eventMap map[string]types.Event +} + // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. -func (v *StateResolution) loadAuthEvents( - ctx context.Context, event *gomatrixserverlib.Event, +func (l *authEventLoader) loadAuthEvents( + ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { - eventMap := map[string]struct{}{} - var lookup []string - var authEvents []types.Event + l.Lock() + defer l.Unlock() + authEvents := []types.Event{} // our returned list + included := map[string]struct{}{} // dedupes authEvents above queue := event.AuthEventIDs() for i := 0; i < len(queue); i++ { - lookup = lookup[:0] + // Reuse the same underlying memory, since it reduces the + // amount of allocations we make the more times we call + // loadAuthEvents. + l.lookupFromDB = l.lookupFromDB[:0] + l.lookupFromMem = l.lookupFromMem[:0] + l.lookedUpEvents = l.lookedUpEvents[:0] + + // Separate out the list of events in the queue based on if + // we think we already know the event in memory or not. for _, authEventID := range queue { - if _, ok := eventMap[authEventID]; ok { + if _, ok := included[authEventID]; ok { continue } - lookup = append(lookup, authEventID) + if _, ok := eventMap[authEventID]; ok { + l.lookupFromMem = append(l.lookupFromMem, authEventID) + } else { + l.lookupFromDB = append(l.lookupFromDB, authEventID) + } } - if len(lookup) == 0 { + // If there's nothing to do, stop here. + if len(l.lookupFromDB) == 0 && len(l.lookupFromMem) == 0 { break } - events, err := v.db.EventsFromIDs(ctx, lookup) - if err != nil { - return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) + + // If we need to get events from the database, go and fetch + // those now. + if len(l.lookupFromDB) > 0 { + eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) + } + l.lookedUpEvents = append(l.lookedUpEvents, eventsFromDB...) + for _, event := range eventsFromDB { + eventMap[event.EventID()] = event + } } + + // Fill in the gaps with events that we already have in memory. + if len(l.lookupFromMem) > 0 { + for _, eventID := range l.lookupFromMem { + l.lookedUpEvents = append(l.lookedUpEvents, eventMap[eventID]) + } + } + + // From the events that we've retrieved, work out which auth + // events to look up on the next iteration. add := map[string]struct{}{} - for _, event := range events { - eventMap[event.EventID()] = struct{}{} + for _, event := range l.lookedUpEvents { authEvents = append(authEvents, event) + included[event.EventID()] = struct{}{} + for _, authEventID := range event.AuthEventIDs() { - if _, ok := eventMap[authEventID]; ok { + if _, ok := included[authEventID]; ok { continue } add[authEventID] = struct{}{} } - for authEventID := range add { - queue = append(queue, authEventID) - } + } + for authEventID := range add { + queue = append(queue, authEventID) } } authEventTypes := map[string]struct{}{} @@ -1028,11 +1149,11 @@ func (v *StateResolution) loadAuthEvents( for eventStateKey := range authEventStateKeys { lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) } - eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) + eventTypes, err := l.v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) if err != nil { return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) } - eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) + eventStateKeys, err := l.v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) if err != nil { return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) }