From 10f1beb0de7a52ccdd122b05b4adffdbdab4ea2e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 15 Oct 2020 12:08:49 +0100 Subject: [PATCH] Don't re-run state resolution on a single trusted state snapshot (#1526) * Don't re-run state resolution on a single trusted state snapshot * Lint * Check if backward extremity is create event before checking missing state --- federationapi/routing/send.go | 72 ++++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 611a90a7..783fdc3b 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -476,6 +476,7 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver return gomatrixserverlib.Allowed(e, &authUsingState) } +// nolint:gocyclo func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { // Do this with a fresh context, so that we keep working even if the // original request times out. With any luck, by the time the remote @@ -513,35 +514,70 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser backwardsExtremity := &newEvents[0] newEvents = newEvents[1:] + type respState struct { + // A snapshot is considered trustworthy if it came from our own roomserver. + // That's because the state will have been through state resolution once + // already in QueryStateAfterEvent. + trustworthy bool + *gomatrixserverlib.RespState + } + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. - var states []*gomatrixserverlib.RespState + var states []*respState for _, prevEventID := range backwardsExtremity.PrevEventIDs() { // Look up what the state is after the backward extremity. This will either // come from the roomserver, if we know all the required events, or it will // come from a remote server via /state_ids if not. - var prevState *gomatrixserverlib.RespState - prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID) - if err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) - return err + prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID) + if lerr != nil { + util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + return lerr } // Append the state onto the collected state. We'll run this through the // state resolution next. - states = append(states, prevState) + states = append(states, &respState{trustworthy, prevState}) } // Now that we have collected all of the state from the prev_events, we'll // run the state through the appropriate state resolution algorithm for the - // room. This does a couple of things: + // room if needed. This does a couple of things: // 1. Ensures that the state is deduplicated fully for each state-key tuple // 2. Ensures that we pick the latest events from both sets, in the case that // one of the prev_events is quite a bit older than the others - resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity) - if err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) - return err + resolvedState := &gomatrixserverlib.RespState{} + switch len(states) { + case 0: + extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") + if !extremityIsCreate { + // There are no previous states and this isn't the beginning of the + // room - this is an error condition! + util.GetLogger(ctx).Errorf("Failed to lookup any state after prev_events") + return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) + } + case 1: + // There's only one previous state - if it's trustworthy (came from a + // local state snapshot which will already have been through state res), + // use it as-is. There's no point in resolving it again. + if states[0].trustworthy { + resolvedState = states[0].RespState + break + } + // Otherwise, if it isn't trustworthy (came from federation), run it through + // state resolution anyway for safety, in case there are duplicates. + fallthrough + default: + respStates := make([]*gomatrixserverlib.RespState, len(states)) + for i := range states { + respStates[i] = states[i].RespState + } + // There's more than one previous state - run them all through state res + resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) + return err + } } // First of all, send the backward extremity into the roomserver with the @@ -581,16 +617,16 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. -func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, error) { +func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { // try doing all this locally before we resort to querying federation respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) if respState != nil { - return respState, nil + return respState, true, nil } respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) if err != nil { - return nil, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) + return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) } servers := t.getServers(ctx, roomID) @@ -602,11 +638,11 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers) switch err.(type) { case verifySigError: - return respState, nil + return respState, false, nil case nil: // do nothing default: - return nil, fmt.Errorf("t.lookupEvent: %w", err) + return nil, false, fmt.Errorf("t.lookupEvent: %w", err) } t.haveEvents[h.EventID()] = h if h.StateKey() != nil { @@ -624,7 +660,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix } } - return respState, nil + return respState, false, nil } func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState {