Don't re-run state resolution on a single trusted state snapshot (#1526)

* Don't re-run state resolution on a single trusted state snapshot

* Lint

* Check if backward extremity is create event before checking missing state
This commit is contained in:
Neil Alexander 2020-10-15 12:08:49 +01:00 committed by GitHub
parent e3c2b081c7
commit 10f1beb0de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -476,6 +476,7 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver
return gomatrixserverlib.Allowed(e, &authUsingState) return gomatrixserverlib.Allowed(e, &authUsingState)
} }
// nolint:gocyclo
func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error {
// Do this with a fresh context, so that we keep working even if the // Do this with a fresh context, so that we keep working even if the
// original request times out. With any luck, by the time the remote // original request times out. With any luck, by the time the remote
@ -513,36 +514,71 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
backwardsExtremity := &newEvents[0] backwardsExtremity := &newEvents[0]
newEvents = newEvents[1:] newEvents = newEvents[1:]
type respState struct {
// A snapshot is considered trustworthy if it came from our own roomserver.
// That's because the state will have been through state resolution once
// already in QueryStateAfterEvent.
trustworthy bool
*gomatrixserverlib.RespState
}
// at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity.
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
var states []*gomatrixserverlib.RespState var states []*respState
for _, prevEventID := range backwardsExtremity.PrevEventIDs() { for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
// Look up what the state is after the backward extremity. This will either // Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will // come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not. // come from a remote server via /state_ids if not.
var prevState *gomatrixserverlib.RespState prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID) if lerr != nil {
if err != nil { util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) return lerr
return err
} }
// Append the state onto the collected state. We'll run this through the // Append the state onto the collected state. We'll run this through the
// state resolution next. // state resolution next.
states = append(states, prevState) states = append(states, &respState{trustworthy, prevState})
} }
// Now that we have collected all of the state from the prev_events, we'll // Now that we have collected all of the state from the prev_events, we'll
// run the state through the appropriate state resolution algorithm for the // run the state through the appropriate state resolution algorithm for the
// room. This does a couple of things: // room if needed. This does a couple of things:
// 1. Ensures that the state is deduplicated fully for each state-key tuple // 1. Ensures that the state is deduplicated fully for each state-key tuple
// 2. Ensures that we pick the latest events from both sets, in the case that // 2. Ensures that we pick the latest events from both sets, in the case that
// one of the prev_events is quite a bit older than the others // one of the prev_events is quite a bit older than the others
resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity) resolvedState := &gomatrixserverlib.RespState{}
switch len(states) {
case 0:
extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("")
if !extremityIsCreate {
// There are no previous states and this isn't the beginning of the
// room - this is an error condition!
util.GetLogger(ctx).Errorf("Failed to lookup any state after prev_events")
return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states))
}
case 1:
// There's only one previous state - if it's trustworthy (came from a
// local state snapshot which will already have been through state res),
// use it as-is. There's no point in resolving it again.
if states[0].trustworthy {
resolvedState = states[0].RespState
break
}
// Otherwise, if it isn't trustworthy (came from federation), run it through
// state resolution anyway for safety, in case there are duplicates.
fallthrough
default:
respStates := make([]*gomatrixserverlib.RespState, len(states))
for i := range states {
respStates[i] = states[i].RespState
}
// There's more than one previous state - run them all through state res
resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())
return err return err
} }
}
// First of all, send the backward extremity into the roomserver with the // First of all, send the backward extremity into the roomserver with the
// newly resolved state. This marks the "oldest" point in the backfill and // newly resolved state. This marks the "oldest" point in the backfill and
@ -581,16 +617,16 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
// added into the mix. // added into the mix.
func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, error) { func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) {
// try doing all this locally before we resort to querying federation // try doing all this locally before we resort to querying federation
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID)
if respState != nil { if respState != nil {
return respState, nil return respState, true, nil
} }
respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID)
if err != nil { if err != nil {
return nil, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
} }
servers := t.getServers(ctx, roomID) servers := t.getServers(ctx, roomID)
@ -602,11 +638,11 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers) h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return respState, nil return respState, false, nil
case nil: case nil:
// do nothing // do nothing
default: default:
return nil, fmt.Errorf("t.lookupEvent: %w", err) return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
} }
t.haveEvents[h.EventID()] = h t.haveEvents[h.EventID()] = h
if h.StateKey() != nil { if h.StateKey() != nil {
@ -624,7 +660,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
} }
} }
return respState, nil return respState, false, nil
} }
func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState {