diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 6ef56513..06a38b9c 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -106,8 +106,8 @@ func Send( eduAPI: eduAPI, keys: keys, federation: federation, + hadEvents: make(map[string]bool), haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), - newEvents: make(map[string]bool), keyAPI: keyAPI, roomsMu: mu, } @@ -167,13 +167,12 @@ type txnReq struct { servers []gomatrixserverlib.ServerName serversMutex sync.RWMutex roomsMu *internal.MutexByRoom + // a list of events from the auth and prev events which we already had + hadEvents map[string]bool // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. haveEvents map[string]*gomatrixserverlib.HeaderedEvent - // new events which the roomserver does not know about - newEvents map[string]bool - newEventsMutex sync.RWMutex - work string // metrics + work string // metrics } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -340,19 +339,6 @@ func (e missingPrevEventsError) Error() string { return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) } -func (t *txnReq) haveEventIDs() map[string]bool { - t.newEventsMutex.RLock() - defer t.newEventsMutex.RUnlock() - result := make(map[string]bool, len(t.haveEvents)) - for eventID := range t.haveEvents { - if t.newEvents[eventID] { - continue - } - result[eventID] = true - } - return result -} - func (t *txnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { eduCountTotal.Inc() @@ -527,6 +513,15 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e return roomNotFoundError{e.RoomID()} } + // Prepare a map of all the events we already had before this point, so + // that we don't send them to the roomserver again. + for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { + t.hadEvents[eventID] = true + } + for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) { + t.hadEvents[eventID] = false + } + if len(stateResp.MissingAuthEventIDs) > 0 { t.work = MetricsWorkMissingAuthEvents logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) @@ -596,6 +591,8 @@ withNextEvent: ); err != nil { return fmt.Errorf("api.SendEvents: %w", err) } + t.hadEvents[ev.EventID()] = true // if the roomserver didn't know about the event before, it does now + t.cacheAndReturn(ev.Headered(stateResp.RoomVersion)) delete(missingAuthEvents, missingAuthEventID) continue withNextEvent } @@ -739,7 +736,7 @@ func (t *txnReq) processEventWithMissingState( api.KindOld, resolvedState, backwardsExtremity.Headered(roomVersion), - t.haveEventIDs(), + t.hadEvents, ) if err != nil { return fmt.Errorf("api.SendEventWithState: %w", err) @@ -791,7 +788,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix default: return nil, false, fmt.Errorf("t.lookupEvent: %w", err) } - t.cacheAndReturn(h) + h = t.cacheAndReturn(h) if h.StateKey() != nil { addedToState := false for i := range respState.StateEvents { @@ -833,6 +830,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. stateEvents[i] = t.cacheAndReturn(ev) + t.hadEvents[ev.EventID()] = true } // we should never access res.StateEvents again so we delete it here to make GC faster res.StateEvents = nil @@ -863,8 +861,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil } - for i := range queryRes.Events { + for i, ev := range queryRes.Events { authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + t.hadEvents[ev.EventID()] = true } queryRes.Events = nil } @@ -939,8 +938,9 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even return nil, err } latestEvents := make([]string, len(res.LatestEvents)) - for i := range res.LatestEvents { + for i, ev := range res.LatestEvents { latestEvents[i] = res.LatestEvents[i].EventID + t.hadEvents[ev.EventID] = true } var missingResp *gomatrixserverlib.RespMissingEvents @@ -985,6 +985,12 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even // For now, we do not allow Case B, so reject the event. logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) + // Make sure events from the missingResp are using the cache - missing events + // will be added and duplicates will be removed. + for i, ev := range missingResp.Events { + missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + // topologically sort and sanity check that we are making forward progress newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) shouldHaveSomeEventIDs := e.PrevEventIDs() @@ -1023,6 +1029,14 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID if err := state.Check(ctx, t.keys, nil); err != nil { return nil, err } + // Cache the results of this state lookup and deduplicate anything we already + // have in the cache, freeing up memory. + for i, ev := range state.AuthEvents { + state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + for i, ev := range state.StateEvents { + state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } return &state, nil } @@ -1055,9 +1069,10 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil, err } - for i := range queryRes.Events { + for i, ev := range queryRes.Events { + queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + t.hadEvents[ev.EventID()] = true evID := queryRes.Events[i].EventID() - t.cacheAndReturn(queryRes.Events[i]) if missing[evID] { delete(missing, evID) } @@ -1221,9 +1236,5 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) return nil, verifySigError{event.EventID(), err} } - h := event.Headered(roomVersion) - t.newEventsMutex.Lock() - t.newEvents[h.EventID()] = true - t.newEventsMutex.Unlock() - return h, nil + return t.cacheAndReturn(event.Headered(roomVersion)), nil } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b14cbd35..98ff1a0a 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -370,7 +370,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat keys: &test.NopJSONVerifier{}, federation: fedClient, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), - newEvents: make(map[string]bool), + hadEvents: make(map[string]bool), roomsMu: internal.NewMutexByRoom(), } t.PDUs = pdus