From 895ead804893191b34fd52a549b22331997d45f7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 7 Sep 2020 12:32:40 +0100 Subject: [PATCH] Use background context when processing event with missing state (#1403) * Use background context when processing event with missing state * Five minute timeout * Remove context from txnreq, thread through instead * Fix unit tests --- federationapi/routing/send.go | 162 +++++++++++++++-------------- federationapi/routing/send_test.go | 3 +- 2 files changed, 85 insertions(+), 80 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index beb7d461..c6e2a3dc 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "net/http" + "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" @@ -43,7 +44,6 @@ func Send( federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { t := txnReq{ - context: httpReq.Context(), rsAPI: rsAPI, eduAPI: eduAPI, keys: keys, @@ -82,7 +82,7 @@ func Send( util.GetLogger(httpReq.Context()).Infof("Received transaction %q containing %d PDUs, %d EDUs", txnID, len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction() + resp, jsonErr := t.processTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -100,7 +100,6 @@ func Send( type txnReq struct { gomatrixserverlib.Transaction - context context.Context rsAPI api.RoomserverInternalAPI eduAPI eduserverAPI.EDUServerInputAPI keyAPI keyapi.KeyInternalAPI @@ -124,7 +123,7 @@ type txnFederationClient interface { roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } -func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONResponse) { +func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { results := make(map[string]gomatrixserverlib.PDUResult) pdus := []gomatrixserverlib.HeaderedEvent{} @@ -133,15 +132,15 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe RoomID string `json:"room_id"` } if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to extract room ID from event") + util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to extract room ID from event") // We don't know the event ID at this point so we can't return the // failure in the PDU results continue } verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(t.context, &verReq, &verRes); err != nil { - util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) // We don't know the event ID at this point so we can't return the // failure in the PDU results continue @@ -161,17 +160,17 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe JSON: jsonerror.BadJSON("PDU contains bad JSON"), } } - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) continue } - if api.IsServerBannedFromRoom(t.context, t.rsAPI, event.RoomID(), t.Origin) { + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: "Forbidden by server ACLs", } continue } - if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), } @@ -182,7 +181,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe // Process the events. for _, e := range pdus { - if err := t.processEvent(e.Unwrap(), true); err != nil { + if err := t.processEvent(ctx, e.Unwrap(), true); err != nil { // If the error is due to the event itself being bad then we skip // it and move onto the next event. We report an error so that the // sender knows that we have skipped processing it. @@ -201,7 +200,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe if isProcessingErrorFatal(err) { // Any other error should be the result of a temporary error in // our server so we should bail processing the transaction entirely. - util.GetLogger(t.context).Warnf("Processing %s failed fatally: %s", e.EventID(), err) + util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err) jsonErr := util.ErrorResponse(err) return nil, &jsonErr } else { @@ -211,7 +210,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe if rejected { errMsg = "" } - util.GetLogger(t.context).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( + util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( "Failed to process incoming federation event, skipping", ) results[e.EventID()] = gomatrixserverlib.PDUResult{ @@ -223,9 +222,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe } } - t.processEDUs(t.EDUs) + t.processEDUs(ctx) if c := len(results); c > 0 { - util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) + util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) } return &gomatrixserverlib.RespSend{PDUs: results}, nil } @@ -284,8 +283,9 @@ func (t *txnReq) haveEventIDs() map[string]bool { return result } -func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { - for _, e := range edus { +// nolint:gocyclo +func (t *txnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { switch e.Type { case gomatrixserverlib.MTyping: // https://matrix.org/docs/spec/server_server/latest#typing-notifications @@ -295,24 +295,24 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { Typing bool `json:"typing"` } if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal typing event") + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal typing event") continue } - if err := eduserverAPI.SendTyping(t.context, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server") + if err := eduserverAPI.SendTyping(ctx, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to edu server") } case gomatrixserverlib.MDirectToDevice: // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema var directPayload gomatrixserverlib.ToDeviceMessage if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events") + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal send-to-device events") continue } for userID, byUser := range directPayload.Messages { for deviceID, message := range byUser { // TODO: check that the user and the device actually exist here - if err := eduserverAPI.SendToDevice(t.context, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{ + if err := eduserverAPI.SendToDevice(ctx, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ "sender": directPayload.Sender, "user_id": userID, "device_id": deviceID, @@ -321,17 +321,17 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { } } case gomatrixserverlib.MDeviceListUpdate: - t.processDeviceListUpdate(e) + t.processDeviceListUpdate(ctx, e) default: - util.GetLogger(t.context).WithField("type", e.Type).Debug("Unhandled EDU") + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") } } } -func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) { +func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverlib.EDU) { var payload gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal device list update event") + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event") return } var inputRes keyapi.InputDeviceListUpdateResponse @@ -339,11 +339,11 @@ func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) { Event: payload, }, &inputRes) if inputRes.Error != nil { - util.GetLogger(t.context).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") + util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") } } -func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error { +func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error { prevEventIDs := e.PrevEventIDs() // Fetch the state needed to authenticate the event. @@ -354,7 +354,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro StateToFetch: needed.Tuples(), } var stateResp api.QueryStateAfterEventsResponse - if err := t.rsAPI.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil { + if err := t.rsAPI.QueryStateAfterEvents(ctx, &stateReq, &stateResp); err != nil { return err } @@ -369,7 +369,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro } if !stateResp.PrevEventsExist { - return t.processEventWithMissingState(e, stateResp.RoomVersion, isInboundTxn) + return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) } // Check that the event is allowed by the state at the event. @@ -379,7 +379,8 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro // pass the event to the roomserver return api.SendEvents( - t.context, t.rsAPI, + context.Background(), + t.rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), }, @@ -399,7 +400,12 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver return gomatrixserverlib.Allowed(e, &authUsingState) } -func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { +func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { + // Do this with a fresh context, so that we keep working even if the + // original request times out. With any luck, by the time the remote + // side retries, we'll have fetched the missing state. + gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -420,7 +426,7 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - backwardsExtremity, err := t.getMissingEvents(e, roomVersion, isInboundTxn) + backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, isInboundTxn) if err != nil { return err } @@ -437,16 +443,16 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() for _, prevEventID := range backwardsExtremity.PrevEventIDs() { var prevState *gomatrixserverlib.RespState - prevState, err = t.lookupStateAfterEvent(roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) + prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) if err != nil { - util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) return err } states = append(states, prevState) } - resolvedState, err := t.resolveStatesAndCheck(roomVersion, states, backwardsExtremity) + resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity) if err != nil { - util.GetLogger(t.context).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) + util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) return err } @@ -457,20 +463,20 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. -func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { +func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { // try doing all this locally before we resort to querying federation - respState := t.lookupStateAfterEventLocally(roomID, eventID, needed) + respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed) if respState != nil { return respState, nil } - respState, err := t.lookupStateBeforeEvent(roomVersion, roomID, eventID) + respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) if err != nil { return nil, err } // fetch the event we're missing and add it to the pile - h, err := t.lookupEvent(roomVersion, eventID, false) + h, err := t.lookupEvent(ctx, roomVersion, eventID, false) if err != nil { return nil, err } @@ -493,15 +499,15 @@ func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion return respState, nil } -func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { +func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { var res api.QueryStateAfterEventsResponse - err := t.rsAPI.QueryStateAfterEvents(t.context, &api.QueryStateAfterEventsRequest{ + err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ RoomID: roomID, PrevEventIDs: []string{eventID}, StateToFetch: needed, }, &res) if err != nil || !res.PrevEventsExist { - util.GetLogger(t.context).WithError(err).Warnf("failed to query state after %s locally", eventID) + util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) return nil } for i, ev := range res.StateEvents { @@ -528,9 +534,9 @@ func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []g queryReq := api.QueryEventsByIDRequest{ EventIDs: missingEventList, } - util.GetLogger(t.context).Infof("Fetching missing auth events: %v", missingEventList) + util.GetLogger(ctx).Infof("Fetching missing auth events: %v", missingEventList) var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil } for i := range queryRes.Events { @@ -548,22 +554,22 @@ func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []g // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what // the server supports. -func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( +func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( respState *gomatrixserverlib.RespState, err error) { - util.GetLogger(t.context).Infof("lookupStateBeforeEvent %s", eventID) + util.GetLogger(ctx).Infof("lookupStateBeforeEvent %s", eventID) // Attempt to fetch the missing state using /state_ids and /events - respState, err = t.lookupMissingStateViaStateIDs(roomID, eventID, roomVersion) + respState, err = t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) if err != nil { // Fallback to /state - util.GetLogger(t.context).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state") - respState, err = t.lookupMissingStateViaState(roomID, eventID, roomVersion) + util.GetLogger(ctx).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state") + respState, err = t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) } return } -func (t *txnReq) resolveStatesAndCheck(roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { +func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { var authEventList []gomatrixserverlib.Event var stateEventList []gomatrixserverlib.Event for _, state := range states { @@ -579,11 +585,11 @@ retryAllowedState: if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(roomVersion, missing.AuthEventID, true) + h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true) if err2 != nil { return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) } - util.GetLogger(t.context).Infof("fetched event %s", missing.AuthEventID) + util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) goto retryAllowedState default: @@ -600,12 +606,12 @@ retryAllowedState: // begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This means that we may recursively call this function, as we spider back up prev_events to the min depth. -func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { +func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { if !isInboundTxn { // we've recursed here, so just take a state snapshot please! return &e, nil } - logger := util.GetLogger(t.context).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) // query latest events (our trusted forward extremities) req := api.QueryLatestEventsAndStateRequest{ @@ -613,7 +619,7 @@ func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatri StateToFetch: needed.Tuples(), } var res api.QueryLatestEventsAndStateResponse - if err = t.rsAPI.QueryLatestEventsAndState(t.context, &req, &res); err != nil { + if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { logger.WithError(err).Warn("Failed to query latest events") return &e, nil } @@ -626,7 +632,7 @@ func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatri if minDepth < 0 { minDepth = 0 } - missingResp, err := t.federation.LookupMissingEvents(t.context, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{ + missingResp, err := t.federation.LookupMissingEvents(ctx, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // synapse uses the min depth they've ever seen in that room MinDepth: minDepth, @@ -685,7 +691,7 @@ Event: } // process the missing events then the event which started this whole thing for _, ev := range append(newEvents, e) { - err := t.processEvent(ev, false) + err := t.processEvent(ctx, ev, false) if err != nil { return nil, err } @@ -695,24 +701,24 @@ Event: return nil, nil } -func (t *txnReq) lookupMissingStateViaState(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( respState *gomatrixserverlib.RespState, err error) { - state, err := t.federation.LookupState(t.context, t.Origin, roomID, eventID, roomVersion) + state, err := t.federation.LookupState(ctx, t.Origin, roomID, eventID, roomVersion) if err != nil { return nil, err } // Check that the returned state is valid. - if err := state.Check(t.context, t.keys, nil); err != nil { + if err := state.Check(ctx, t.keys, nil); err != nil { return nil, err } return &state, nil } -func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *gomatrixserverlib.RespState, error) { - util.GetLogger(t.context).Infof("lookupMissingStateViaStateIDs %s", eventID) + util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event - stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, roomID, eventID) + stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID) if err != nil { return nil, err } @@ -734,7 +740,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi EventIDs: missingEventList, } var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil, err } for i := range queryRes.Events { @@ -745,7 +751,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi } } - util.GetLogger(t.context).WithFields(logrus.Fields{ + util.GetLogger(ctx).WithFields(logrus.Fields{ "missing": len(missing), "event_id": eventID, "room_id": roomID, @@ -755,7 +761,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi for missingEventID := range missing { var h *gomatrixserverlib.HeaderedEvent - h, err = t.lookupEvent(roomVersion, missingEventID, false) + h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false) if err != nil { return nil, err } @@ -793,33 +799,33 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat return &respState, nil } -func (t *txnReq) lookupEvent(roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { +func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { if localFirst { // fetch from the roomserver queryReq := api.QueryEventsByIDRequest{ EventIDs: []string{missingEventID}, } var queryRes api.QueryEventsByIDResponse - if err := t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { - util.GetLogger(t.context).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) + if err := t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(queryRes.Events) == 1 { return &queryRes.Events[0], nil } } - txn, err := t.federation.GetEvent(t.context, t.Origin, missingEventID) + txn, err := t.federation.GetEvent(ctx, t.Origin, missingEventID) if err != nil || len(txn.PDUs) == 0 { - util.GetLogger(t.context).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") return nil, err } pdu := txn.PDUs[0] var event gomatrixserverlib.Event event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) if err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) return nil, unmarshalError{err} } - if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) return nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index f16fde0e..6b4a3084 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -365,7 +365,6 @@ func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserver func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { t := &txnReq{ - context: context.Background(), rsAPI: rsAPI, eduAPI: &testEDUProducer{}, keys: &test.NopJSONVerifier{}, @@ -381,7 +380,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat } func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction() + res, err := txn.processTransaction(context.Background()) if err != nil { t.Errorf("txn.processTransaction returned an error: %v", err) return