diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 85331192..e89d8ff2 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -451,23 +451,20 @@ func createRoom( util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") return jsonerror.InternalServerError() } + } - accumulated := gomatrixserverlib.UnwrapEventHeaders(builtEvents) - if err = roomserverAPI.SendEventWithState( - req.Context(), - rsAPI, - roomserverAPI.KindNew, - &gomatrixserverlib.RespState{ - StateEvents: accumulated, - AuthEvents: accumulated, - }, - ev.Headered(roomVersion), - nil, - false, - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEventWithState failed") - return jsonerror.InternalServerError() - } + inputs := make([]roomserverAPI.InputRoomEvent, 0, len(builtEvents)) + for _, event := range builtEvents { + inputs = append(inputs, roomserverAPI.InputRoomEvent{ + Kind: roomserverAPI.KindNew, + Event: event, + Origin: cfg.Matrix.ServerName, + SendAsServer: roomserverAPI.DoNotSendToOtherServers, + }) + } + if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + return jsonerror.InternalServerError() } // TODO(#269): Reserve room alias while we create the room. This stops us diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 7ddb827e..4ce82079 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -109,6 +109,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, cfg.Matrix.ServerName, + cfg.Matrix.ServerName, nil, false, ); err != nil { diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 9de1869b..017facd2 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -169,7 +169,7 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -286,7 +286,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 8492236b..01ea818a 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -120,7 +120,7 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, nil, false); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index f0498312..606107b9 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -121,6 +121,7 @@ func SendEvent( e.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, + cfg.Matrix.ServerName, txnAndSessionID, false, ); err != nil { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 985cf00c..db62ce06 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -366,6 +366,7 @@ func emit3PIDInviteEvent( event.Headered(queryRes.RoomVersion), }, cfg.Matrix.ServerName, + cfg.Matrix.ServerName, nil, false, ) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 7be19dad..f5ee75b4 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -23,6 +23,8 @@ type FederationClient interface { MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 632adae3..25ea7827 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -66,7 +66,11 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) + _, err := s.jetstream.Subscribe( + s.topic, s.onMessage, s.durable, + nats.DeliverAll(), + nats.ManualAck(), + ) return err } diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 1f31b07c..4e9fa841 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -1,9 +1,9 @@ package internal import ( - "context" "crypto/ed25519" "encoding/base64" + "fmt" "sync" "time" @@ -142,7 +142,7 @@ func failBlacklistableError(err error, stats *statistics.ServerStatistics) (unti return } -func (a *FederationInternalAPI) doRequest( +func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( s gomatrixserverlib.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats, err := a.isBlacklistedOrBackingOff(s) @@ -167,141 +167,15 @@ func (a *FederationInternalAPI) doRequest( return res, nil } -func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, -) (gomatrixserverlib.RespUserDevices, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.GetUserDevices(ctx, s, userID) - }) - if err != nil { - return gomatrixserverlib.RespUserDevices{}, err +func (a *FederationInternalAPI) doRequestIfNotBlacklisted( + s gomatrixserverlib.ServerName, request func() (interface{}, error), +) (interface{}, error) { + stats := a.statistics.ForServer(s) + if _, blacklisted := stats.BackoffInfo(); blacklisted { + return stats, &api.FederationClientError{ + Err: fmt.Sprintf("server %q is blacklisted", s), + Blacklisted: true, + } } - return ires.(gomatrixserverlib.RespUserDevices), nil -} - -func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, -) (gomatrixserverlib.RespClaimKeys, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.ClaimKeys(ctx, s, oneTimeKeys) - }) - if err != nil { - return gomatrixserverlib.RespClaimKeys{}, err - } - return ires.(gomatrixserverlib.RespClaimKeys), nil -} - -func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, -) (gomatrixserverlib.RespQueryKeys, error) { - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.QueryKeys(ctx, s, keys) - }) - if err != nil { - return gomatrixserverlib.RespQueryKeys{}, err - } - return ires.(gomatrixserverlib.RespQueryKeys), nil -} - -func (a *FederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, -) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) - }) - if err != nil { - return gomatrixserverlib.Transaction{}, err - } - return ires.(gomatrixserverlib.Transaction), nil -} - -func (a *FederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.RespState, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) - }) - if err != nil { - return gomatrixserverlib.RespState{}, err - } - return ires.(gomatrixserverlib.RespState), nil -} - -func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, -) (res gomatrixserverlib.RespStateIDs, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.LookupStateIDs(ctx, s, roomID, eventID) - }) - if err != nil { - return gomatrixserverlib.RespStateIDs{}, err - } - return ires.(gomatrixserverlib.RespStateIDs), nil -} - -func (a *FederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, -) (res gomatrixserverlib.Transaction, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.GetEvent(ctx, s, eventID) - }) - if err != nil { - return gomatrixserverlib.Transaction{}, err - } - return ires.(gomatrixserverlib.Transaction), nil -} - -func (a *FederationInternalAPI) LookupServerKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, -) ([]gomatrixserverlib.ServerKeys, error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.LookupServerKeys(ctx, s, keyRequests) - }) - if err != nil { - return []gomatrixserverlib.ServerKeys{}, err - } - return ires.([]gomatrixserverlib.ServerKeys), nil -} - -func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, - roomVersion gomatrixserverlib.RoomVersion, -) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) - }) - if err != nil { - return res, err - } - return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil -} - -func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, -) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - ires, err := a.doRequest(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, r) - }) - if err != nil { - return res, err - } - return ires.(gomatrixserverlib.MSC2946SpacesResponse), nil + return request() } diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go new file mode 100644 index 00000000..b31db466 --- /dev/null +++ b/federationapi/internal/federationclient.go @@ -0,0 +1,180 @@ +package internal + +import ( + "context" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +// Functions here are "proxying" calls to the gomatrixserverlib federation +// client. + +func (a *FederationInternalAPI) GetEventAuth( + ctx context.Context, s gomatrixserverlib.ServerName, + roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, +) (res gomatrixserverlib.RespEventAuth, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) + }) + if err != nil { + return gomatrixserverlib.RespEventAuth{}, err + } + return ires.(gomatrixserverlib.RespEventAuth), nil +} + +func (a *FederationInternalAPI) GetUserDevices( + ctx context.Context, s gomatrixserverlib.ServerName, userID string, +) (gomatrixserverlib.RespUserDevices, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.GetUserDevices(ctx, s, userID) + }) + if err != nil { + return gomatrixserverlib.RespUserDevices{}, err + } + return ires.(gomatrixserverlib.RespUserDevices), nil +} + +func (a *FederationInternalAPI) ClaimKeys( + ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, +) (gomatrixserverlib.RespClaimKeys, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { + return a.federation.ClaimKeys(ctx, s, oneTimeKeys) + }) + if err != nil { + return gomatrixserverlib.RespClaimKeys{}, err + } + return ires.(gomatrixserverlib.RespClaimKeys), nil +} + +func (a *FederationInternalAPI) QueryKeys( + ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, +) (gomatrixserverlib.RespQueryKeys, error) { + ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { + return a.federation.QueryKeys(ctx, s, keys) + }) + if err != nil { + return gomatrixserverlib.RespQueryKeys{}, err + } + return ires.(gomatrixserverlib.RespQueryKeys), nil +} + +func (a *FederationInternalAPI) Backfill( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, +) (res gomatrixserverlib.Transaction, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + }) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + return ires.(gomatrixserverlib.Transaction), nil +} + +func (a *FederationInternalAPI) LookupState( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.RespState, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + }) + if err != nil { + return gomatrixserverlib.RespState{}, err + } + return ires.(gomatrixserverlib.RespState), nil +} + +func (a *FederationInternalAPI) LookupStateIDs( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, +) (res gomatrixserverlib.RespStateIDs, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + }) + if err != nil { + return gomatrixserverlib.RespStateIDs{}, err + } + return ires.(gomatrixserverlib.RespStateIDs), nil +} + +func (a *FederationInternalAPI) LookupMissingEvents( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.RespMissingEvents, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) + }) + if err != nil { + return gomatrixserverlib.RespMissingEvents{}, err + } + return ires.(gomatrixserverlib.RespMissingEvents), nil +} + +func (a *FederationInternalAPI) GetEvent( + ctx context.Context, s gomatrixserverlib.ServerName, eventID string, +) (res gomatrixserverlib.Transaction, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.GetEvent(ctx, s, eventID) + }) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + return ires.(gomatrixserverlib.Transaction), nil +} + +func (a *FederationInternalAPI) LookupServerKeys( + ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) ([]gomatrixserverlib.ServerKeys, error) { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.LookupServerKeys(ctx, s, keyRequests) + }) + if err != nil { + return []gomatrixserverlib.ServerKeys{}, err + } + return ires.([]gomatrixserverlib.ServerKeys), nil +} + +func (a *FederationInternalAPI) MSC2836EventRelationships( + ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) + }) + if err != nil { + return res, err + } + return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil +} + +func (a *FederationInternalAPI) MSC2946Spaces( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest, +) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.MSC2946Spaces(ctx, s, roomID, r) + }) + if err != nil { + return res, err + } + return ires.(gomatrixserverlib.MSC2946SpacesResponse), nil +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index b6c35842..4dd53c11 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -249,7 +249,9 @@ func (r *FederationInternalAPI) performJoinUsingServer( roomserverAPI.KindNew, respState, event.Headered(respMakeJoin.RoomVersion), - nil, false, + serverName, + nil, + false, ); err != nil { logrus.WithFields(logrus.Fields{ "room_id": roomID, @@ -430,7 +432,9 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( roomserverAPI.KindNew, &respState, respPeek.LatestEvent.Headered(respPeek.RoomVersion), - nil, false, + serverName, + nil, + false, ); err != nil { return fmt.Errorf("r.producer.SendEventWithState: %w", err) } diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index 31d1a3c4..b0a76eeb 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -28,7 +28,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - ires, err := a.doRequest(serverName, func() (interface{}, error) { + ires, err := a.doRequestIfNotBackingOffOrBlacklisted(serverName, func() (interface{}, error) { return a.federation.GetServerKeys(ctx, serverName) }) if err != nil { diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index af6b801b..a65df906 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -26,16 +26,18 @@ const ( FederationAPIPerformServersAlivePath = "/federationapi/performServersAlive" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" - FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" - FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" - FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" - FederationAPIBackfillPath = "/federationapi/client/backfill" - FederationAPILookupStatePath = "/federationapi/client/lookupState" - FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs" - FederationAPIGetEventPath = "/federationapi/client/getEvent" - FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys" - FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships" - FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary" + FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" + FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" + FederationAPIQueryKeysPath = "/federationapi/client/queryKeys" + FederationAPIBackfillPath = "/federationapi/client/backfill" + FederationAPILookupStatePath = "/federationapi/client/lookupState" + FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs" + FederationAPILookupMissingEventsPath = "/federationapi/client/lookupMissingEvents" + FederationAPIGetEventPath = "/federationapi/client/getEvent" + FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys" + FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships" + FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary" + FederationAPIGetEventAuthPath = "/federationapi/client/getEventAuth" FederationAPIInputPublicKeyPath = "/federationapi/inputPublicKey" FederationAPIQueryPublicKeyPath = "/federationapi/queryPublicKey" @@ -353,6 +355,49 @@ func (h *httpFederationInternalAPI) LookupStateIDs( return *response.Res, nil } +type lookupMissingEvents struct { + S gomatrixserverlib.ServerName + RoomID string + Missing gomatrixserverlib.MissingEvents + RoomVersion gomatrixserverlib.RoomVersion + Res struct { + Events []gomatrixserverlib.RawJSON `json:"events"` + } + Err *api.FederationClientError +} + +func (h *httpFederationInternalAPI) LookupMissingEvents( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.RespMissingEvents, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents") + defer span.Finish() + + request := lookupMissingEvents{ + S: s, + RoomID: roomID, + Missing: missing, + RoomVersion: roomVersion, + } + apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath + err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request) + if err != nil { + return res, err + } + if request.Err != nil { + return res, request.Err + } + res.Events = make([]*gomatrixserverlib.Event, 0, len(request.Res.Events)) + for _, js := range request.Res.Events { + ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(js, roomVersion) + if err != nil { + return res, err + } + res.Events = append(res.Events, ev) + } + return res, nil +} + type getEvent struct { S gomatrixserverlib.ServerName EventID string @@ -382,6 +427,40 @@ func (h *httpFederationInternalAPI) GetEvent( return *response.Res, nil } +type getEventAuth struct { + S gomatrixserverlib.ServerName + RoomVersion gomatrixserverlib.RoomVersion + RoomID string + EventID string + Res *gomatrixserverlib.RespEventAuth + Err *api.FederationClientError +} + +func (h *httpFederationInternalAPI) GetEventAuth( + ctx context.Context, s gomatrixserverlib.ServerName, + roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, +) (gomatrixserverlib.RespEventAuth, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth") + defer span.Finish() + + request := getEventAuth{ + S: s, + RoomVersion: roomVersion, + RoomID: roomID, + EventID: eventID, + } + var response getEventAuth + apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.RespEventAuth{}, err + } + if response.Err != nil { + return gomatrixserverlib.RespEventAuth{}, response.Err + } + return *response.Res, nil +} + func (h *httpFederationInternalAPI) QueryServerKeys( ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse, ) error { diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7133eddd..8d193d9c 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -241,6 +241,34 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationAPILookupMissingEventsPath, + httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse { + var request lookupMissingEvents + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + for _, event := range res.Events { + js, err := json.Marshal(event) + if err != nil { + return util.MessageResponse(http.StatusInternalServerError, err.Error()) + } + request.Res.Events = append(request.Res.Events, js) + } + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) internalAPIMux.Handle( FederationAPIGetEventPath, httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { @@ -263,6 +291,28 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationAPIGetEventAuthPath, + httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse { + var request getEventAuth + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) internalAPIMux.Handle( FederationAPIQueryServerKeysPath, httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse { diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 7310a305..7f8d3150 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -324,7 +324,6 @@ func SendJoin( { Kind: api.KindNew, Event: event.Headered(stateAndAuthChainResponse.RoomVersion), - AuthEventIDs: event.AuthEventIDs(), SendAsServer: string(cfg.Matrix.ServerName), TransactionID: nil, }, diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 6312adfa..0b83f04a 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -275,7 +275,6 @@ func SendLeave( { Kind: api.KindNew, Event: event.Headered(verRes.RoomVersion), - AuthEventIDs: event.AuthEventIDs(), SendAsServer: string(cfg.Matrix.ServerName), TransactionID: nil, }, diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 5b9be880..a253f86e 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -133,8 +133,6 @@ func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.Room util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") return nil, err } - util.GetLogger(ctx).Infof("room IDs: %+v", roomIDs) - util.GetLogger(ctx).Infof("State res: %+v", stateRes.Rooms) chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) i := 0 for roomID, data := range stateRes.Rooms { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index fad23a5c..dbfd3ff9 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,7 +17,6 @@ package routing import ( "context" "encoding/json" - "errors" "fmt" "net/http" "sync" @@ -34,7 +33,6 @@ import ( "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" - "go.uber.org/atomic" ) const ( @@ -72,84 +70,15 @@ var ( Help: "Number of incoming EDUs from remote servers", }, ) - processEventSummary = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "process_event", - Help: "How long it takes to process an incoming event and what work had to be done for it", - }, - []string{"work", "outcome"}, - ) ) func init() { prometheus.MustRegister( - pduCountTotal, eduCountTotal, processEventSummary, + pduCountTotal, eduCountTotal, ) } -type sendFIFOQueue struct { - tasks []*inputTask - count int - mutex sync.Mutex - notifs chan struct{} -} - -func newSendFIFOQueue() *sendFIFOQueue { - q := &sendFIFOQueue{ - notifs: make(chan struct{}, 1), - } - return q -} - -func (q *sendFIFOQueue) push(frame *inputTask) { - q.mutex.Lock() - defer q.mutex.Unlock() - q.tasks = append(q.tasks, frame) - q.count++ - select { - case q.notifs <- struct{}{}: - default: - } -} - -// pop returns the first item of the queue, if there is one. -// The second return value will indicate if a task was returned. -func (q *sendFIFOQueue) pop() (*inputTask, bool) { - q.mutex.Lock() - defer q.mutex.Unlock() - if q.count == 0 { - return nil, false - } - frame := q.tasks[0] - q.tasks[0] = nil - q.tasks = q.tasks[1:] - q.count-- - if q.count == 0 { - // Force a GC of the underlying array, since it might have - // grown significantly if the queue was hammered for some reason - q.tasks = nil - } - return frame, true -} - -type inputTask struct { - ctx context.Context - t *txnReq - event *gomatrixserverlib.Event - wg *sync.WaitGroup - err error // written back by worker, only safe to read when all tasks are done - duration time.Duration // written back by worker, only safe to read when all tasks are done -} - -type inputWorker struct { - running atomic.Bool - input *sendFIFOQueue -} - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse -var inputWorkers sync.Map // room ID -> *inputWorker // Send implements /_matrix/federation/v1/send/{txnID} func Send( @@ -201,8 +130,6 @@ func Send( eduAPI: eduAPI, keys: keys, federation: federation, - hadEvents: make(map[string]bool), - haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), servers: servers, keyAPI: keyAPI, roomsMu: mu, @@ -237,7 +164,7 @@ func Send( util.GetLogger(httpReq.Context()).Infof("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(context.Background()) + resp, jsonErr := t.processTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -263,22 +190,7 @@ type txnReq struct { keys gomatrixserverlib.JSONVerifier federation txnFederationClient roomsMu *internal.MutexByRoom - // something that can tell us about which servers are in a room right now - servers federationAPI.ServersInRoomProvider - // a list of events from the auth and prev events which we already had - hadEvents map[string]bool - hadEventsMutex sync.Mutex - // local cache of events for auth checks, etc - this may include events - // which the roomserver is unaware of. - haveEvents map[string]*gomatrixserverlib.HeaderedEvent - haveEventsMutex sync.Mutex - work string // metrics -} - -func (t *txnReq) hadEvent(eventID string, had bool) { - t.hadEventsMutex.Lock() - defer t.hadEventsMutex.Unlock() - t.hadEvents[eventID] = had + servers federationAPI.ServersInRoomProvider } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -293,9 +205,28 @@ type txnFederationClient interface { } func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - results := make(map[string]gomatrixserverlib.PDUResult) var wg sync.WaitGroup - var tasks []*inputTask + wg.Add(1) + go func() { + defer wg.Done() + t.processEDUs(ctx) + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } for _, pdu := range t.PDUs { pduCountTotal.WithLabelValues("total").Inc() @@ -308,15 +239,8 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res // failure in the PDU results continue } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, verRes.RoomVersion) + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) if err != nil { if _, ok := err.(gomatrixserverlib.BadJSONError); ok { // Room version 6 states that homeservers should strictly enforce canonical JSON @@ -347,114 +271,35 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res } continue } - v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{ - input: newSendFIFOQueue(), - }) - worker := v.(*inputWorker) - wg.Add(1) - task := &inputTask{ - ctx: ctx, - t: t, - event: event, - wg: &wg, - } - tasks = append(tasks, task) - worker.input.push(task) - if worker.running.CAS(false, true) { - go worker.run() - } - } - t.processEDUs(ctx) - wg.Wait() - - for _, task := range tasks { - if task.err != nil { - results[task.event.EventID()] = gomatrixserverlib.PDUResult{ - // Error: task.err.Error(), TODO: this upsets tests if uncommented + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + if err = api.SendEvents( + ctx, + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Origin, + api.DoNotSendToOtherServers, + nil, + true, + ); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), } - } else { - results[task.event.EventID()] = gomatrixserverlib.PDUResult{} - } - } - - if c := len(results); c > 0 { - util.GetLogger(ctx).Infof("Processed %d PDUs from %v in transaction %q", c, t.Origin, t.TransactionID) - } - return &gomatrixserverlib.RespSend{PDUs: results}, nil -} - -func (t *inputWorker) run() { - defer t.running.Store(false) - for { - task, ok := t.input.pop() - if !ok { - return - } - if task == nil { continue } - func() { - defer task.wg.Done() - select { - case <-task.ctx.Done(): - task.err = context.DeadlineExceeded - pduCountTotal.WithLabelValues("expired").Inc() - return - default: - evStart := time.Now() - // TODO: Is 5 minutes too long? - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - task.err = task.t.processEvent(ctx, task.event) - cancel() - task.duration = time.Since(evStart) - if err := task.err; err != nil { - switch err.(type) { - case *gomatrixserverlib.NotAllowed: - processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn( - "Failed to process incoming federation event, skipping", - ) - task.err = nil // make "rejected" failures silent - default: - processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn( - "Failed to process incoming federation event, skipping", - ) - } - } else { - pduCountTotal.WithLabelValues("success").Inc() - processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - } - } - }() + + results[event.EventID()] = gomatrixserverlib.PDUResult{} + pduCountTotal.WithLabelValues("success").Inc() } -} -type roomNotFoundError struct { - roomID string -} -type verifySigError struct { - eventID string - err error -} -type missingPrevEventsError struct { - eventID string - err error -} - -func (e roomNotFoundError) Error() string { return fmt.Sprintf("room %q not found", e.roomID) } -func (e verifySigError) Error() string { - return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err) -} -func (e missingPrevEventsError) Error() string { - return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) + wg.Wait() + return &gomatrixserverlib.RespSend{PDUs: results}, nil } func (t *txnReq) processEDUs(ctx context.Context) { @@ -598,811 +443,3 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") } } - -func (t *txnReq) getServers(ctx context.Context, roomID string, event *gomatrixserverlib.Event) []gomatrixserverlib.ServerName { - // The server that sent us the event should be sufficient to tell us about missing - // prev and auth events. - servers := []gomatrixserverlib.ServerName{t.Origin} - // If the event origin is different to the transaction origin then we can use - // this as a last resort. The origin server that created the event would have - // had to know the auth and prev events. - if event != nil { - if origin := event.Origin(); origin != t.Origin { - servers = append(servers, origin) - } - } - // If a specific room-to-server provider exists then use that. This will primarily - // be used for the P2P demos. - if t.servers != nil { - servers = append(servers, t.servers.GetServersForRoom(ctx, roomID, event)...) - } - return servers -} - -func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { - logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - t.work = "" // reset from previous event - - // Ask the roomserver if we know about the room and/or if we're joined - // to it. If we aren't then we won't bother processing the event. - joinedReq := api.QueryServerJoinedToRoomRequest{ - RoomID: e.RoomID(), - } - var joinedRes api.QueryServerJoinedToRoomResponse - if err := t.rsAPI.QueryServerJoinedToRoom(ctx, &joinedReq, &joinedRes); err != nil { - return fmt.Errorf("t.rsAPI.QueryServerJoinedToRoom: %w", err) - } - - if !joinedRes.RoomExists || !joinedRes.IsInRoom { - // We don't believe we're a member of this room, therefore there's - // no point in wasting work trying to figure out what to do with - // missing auth or prev events. Drop the event. - return roomNotFoundError{e.RoomID()} - } - - // Work out if the roomserver knows everything it needs to know to auth - // the event. This includes the prev_events and auth_events. - // NOTE! This is going to include prev_events that have an empty state - // snapshot. This is because we will need to re-request the event, and - // it's /state_ids, in order for it to exist in the roomserver correctly - // before the roomserver tries to work out - stateReq := api.QueryMissingAuthPrevEventsRequest{ - RoomID: e.RoomID(), - AuthEventIDs: e.AuthEventIDs(), - PrevEventIDs: e.PrevEventIDs(), - } - var stateResp api.QueryMissingAuthPrevEventsResponse - if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil { - return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err) - } - - // Prepare a map of all the events we already had before this point, so - // that we don't send them to the roomserver again. - for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { - t.hadEvent(eventID, true) - } - for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) { - t.hadEvent(eventID, false) - } - - if len(stateResp.MissingAuthEventIDs) > 0 { - t.work = MetricsWorkMissingAuthEvents - logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) - if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil { - return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err) - } - } - - if len(stateResp.MissingPrevEventIDs) > 0 { - t.work = MetricsWorkMissingPrevEvents - logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) - return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion) - } - t.work = MetricsWorkDirect - - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - return api.SendEvents( - context.Background(), - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - e.Headered(stateResp.RoomVersion), - }, - api.DoNotSendToOtherServers, - nil, - false, - ) -} - -func (t *txnReq) retrieveMissingAuthEvents( - ctx context.Context, e *gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse, -) error { - logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - - missingAuthEvents := make(map[string]struct{}) - for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { - missingAuthEvents[missingAuthEventID] = struct{}{} - } - -withNextEvent: - for missingAuthEventID := range missingAuthEvents { - withNextServer: - for _, server := range t.getServers(ctx, e.RoomID(), e) { - logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) - tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) - if err != nil { - logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID) - if errors.Is(err, context.DeadlineExceeded) { - return err - } - continue withNextServer - } - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) - if err != nil { - logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID) - continue withNextServer - } - if err = api.SendInputRoomEvents( - context.Background(), - t.rsAPI, - []api.InputRoomEvent{ - { - Kind: api.KindOutlier, - Event: ev.Headered(stateResp.RoomVersion), - AuthEventIDs: ev.AuthEventIDs(), - SendAsServer: api.DoNotSendToOtherServers, - }, - }, - false, - ); err != nil { - return fmt.Errorf("api.SendEvents: %w", err) - } - t.hadEvent(ev.EventID(), true) // if the roomserver didn't know about the event before, it does now - t.cacheAndReturn(ev.Headered(stateResp.RoomVersion)) - delete(missingAuthEvents, missingAuthEventID) - continue withNextEvent - } - } - - if missing := len(missingAuthEvents); missing > 0 { - return fmt.Errorf("event refers to %d auth_events which we failed to fetch", missing) - } - return nil -} - -func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { - authUsingState := gomatrixserverlib.NewAuthEvents(nil) - for i := range stateEvents { - err := authUsingState.AddEvent(stateEvents[i]) - if err != nil { - return err - } - } - return gomatrixserverlib.Allowed(e, &authUsingState) -} - -func (t *txnReq) processEventWithMissingState( - ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, -) error { - // We are missing the previous events for this events. - // This means that there is a gap in our view of the history of the - // room. There two ways that we can handle such a gap: - // 1) We can fill in the gap using /get_missing_events - // 2) We can leave the gap and request the state of the room at - // this event from the remote server using either /state_ids - // or /state. - // Synapse will attempt to do 1 and if that fails or if the gap is - // too large then it will attempt 2. - // Synapse will use /state_ids if possible since usually the state - // is largely unchanged and it is more efficient to fetch a list of - // event ids and then use /event to fetch the individual events. - // However not all version of synapse support /state_ids so you may - // need to fallback to /state. - - // Attempt to fill in the gap using /get_missing_events - // This will either: - // - fill in the gap completely then process event `e` returning no backwards extremity - // - fail to fill in the gap and tell us to terminate the transaction err=not nil - // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, err := t.getMissingEvents(ctx, e, roomVersion) - if err != nil { - return err - } - if len(newEvents) == 0 { - return nil - } - - backwardsExtremity := newEvents[0] - newEvents = newEvents[1:] - - type respState struct { - // A snapshot is considered trustworthy if it came from our own roomserver. - // That's because the state will have been through state resolution once - // already in QueryStateAfterEvent. - trustworthy bool - *gomatrixserverlib.RespState - } - - // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. - // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query - // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. - var states []*respState - for _, prevEventID := range backwardsExtremity.PrevEventIDs() { - // Look up what the state is after the backward extremity. This will either - // come from the roomserver, if we know all the required events, or it will - // come from a remote server via /state_ids if not. - prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) - if lerr != nil { - util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) - return lerr - } - // Append the state onto the collected state. We'll run this through the - // state resolution next. - states = append(states, &respState{trustworthy, prevState}) - } - - // Now that we have collected all of the state from the prev_events, we'll - // run the state through the appropriate state resolution algorithm for the - // room if needed. This does a couple of things: - // 1. Ensures that the state is deduplicated fully for each state-key tuple - // 2. Ensures that we pick the latest events from both sets, in the case that - // one of the prev_events is quite a bit older than the others - resolvedState := &gomatrixserverlib.RespState{} - switch len(states) { - case 0: - extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") - if !extremityIsCreate { - // There are no previous states and this isn't the beginning of the - // room - this is an error condition! - util.GetLogger(ctx).Errorf("Failed to lookup any state after prev_events") - return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) - } - case 1: - // There's only one previous state - if it's trustworthy (came from a - // local state snapshot which will already have been through state res), - // use it as-is. There's no point in resolving it again. - if states[0].trustworthy { - resolvedState = states[0].RespState - break - } - // Otherwise, if it isn't trustworthy (came from federation), run it through - // state resolution anyway for safety, in case there are duplicates. - fallthrough - default: - respStates := make([]*gomatrixserverlib.RespState, len(states)) - for i := range states { - respStates[i] = states[i].RespState - } - // There's more than one previous state - run them all through state res - t.roomsMu.Lock(e.RoomID()) - resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) - t.roomsMu.Unlock(e.RoomID()) - if err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) - return err - } - } - - // First of all, send the backward extremity into the roomserver with the - // newly resolved state. This marks the "oldest" point in the backfill and - // sets the baseline state for any new events after this. We'll make a - // copy of the hadEvents map so that it can be taken downstream without - // worrying about concurrent map reads/writes, since t.hadEvents is meant - // to be protected by a mutex. - hadEvents := map[string]bool{} - t.hadEventsMutex.Lock() - for k, v := range t.hadEvents { - hadEvents[k] = v - } - t.hadEventsMutex.Unlock() - err = api.SendEventWithState( - context.Background(), - t.rsAPI, - api.KindOld, - resolvedState, - backwardsExtremity.Headered(roomVersion), - hadEvents, - // Send the events to the roomserver asynchronously, so they will be - // processed when the roomserver is able, without blocking here. - true, - ) - if err != nil { - return fmt.Errorf("api.SendEventWithState: %w", err) - } - - // Then send all of the newer backfilled events, of which will all be newer - // than the backward extremity, into the roomserver without state. This way - // they will automatically fast-forward based on the room state at the - // extremity in the last step. - headeredNewEvents := make([]*gomatrixserverlib.HeaderedEvent, len(newEvents)) - for i, newEvent := range newEvents { - headeredNewEvents[i] = newEvent.Headered(roomVersion) - } - if err = api.SendEvents( - context.Background(), - t.rsAPI, - api.KindOld, - append(headeredNewEvents, e.Headered(roomVersion)), - api.DoNotSendToOtherServers, - nil, - // Send the events to the roomserver asynchronously, so they will be - // processed when the roomserver is able, without blocking here. - true, - ); err != nil { - return fmt.Errorf("api.SendEvents: %w", err) - } - - return nil -} - -// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) -// added into the mix. -func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { - // try doing all this locally before we resort to querying federation - respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) - if respState != nil { - return respState, true, nil - } - - respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) - if err != nil { - return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) - } - - // fetch the event we're missing and add it to the pile - h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) - switch err.(type) { - case verifySigError: - return respState, false, nil - case nil: - // do nothing - default: - return nil, false, fmt.Errorf("t.lookupEvent: %w", err) - } - h = t.cacheAndReturn(h) - if h.StateKey() != nil { - addedToState := false - for i := range respState.StateEvents { - se := respState.StateEvents[i] - if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { - respState.StateEvents[i] = h.Unwrap() - addedToState = true - break - } - } - if !addedToState { - respState.StateEvents = append(respState.StateEvents, h.Unwrap()) - } - } - - return respState, false, nil -} - -func (t *txnReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { - t.haveEventsMutex.Lock() - defer t.haveEventsMutex.Unlock() - if cached, exists := t.haveEvents[ev.EventID()]; exists { - return cached - } - t.haveEvents[ev.EventID()] = ev - return ev -} - -func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { - var res api.QueryStateAfterEventsResponse - err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ - RoomID: roomID, - PrevEventIDs: []string{eventID}, - }, &res) - if err != nil || !res.PrevEventsExist { - util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) - return nil - } - stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) - for i, ev := range res.StateEvents { - // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this - // processEvent request, which is better for memory. - stateEvents[i] = t.cacheAndReturn(ev) - t.hadEvent(ev.EventID(), true) - } - // we should never access res.StateEvents again so we delete it here to make GC faster - res.StateEvents = nil - - var authEvents []*gomatrixserverlib.Event - missingAuthEvents := map[string]bool{} - for _, ev := range stateEvents { - t.haveEventsMutex.Lock() - for _, ae := range ev.AuthEventIDs() { - if aev, ok := t.haveEvents[ae]; ok { - authEvents = append(authEvents, aev.Unwrap()) - } else { - missingAuthEvents[ae] = true - } - } - t.haveEventsMutex.Unlock() - } - // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't - // have stored the event. - if len(missingAuthEvents) > 0 { - var missingEventList []string - for evID := range missingAuthEvents { - missingEventList = append(missingEventList, evID) - } - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } - util.GetLogger(ctx).WithField("count", len(missingEventList)).Infof("Fetching missing auth events") - var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil - } - for i, ev := range queryRes.Events { - authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) - t.hadEvent(ev.EventID(), true) - } - queryRes.Events = nil - } - - return &gomatrixserverlib.RespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), - AuthEvents: authEvents, - } -} - -// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what -// the server supports. -func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( - *gomatrixserverlib.RespState, error) { - - // Attempt to fetch the missing state using /state_ids and /events - return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) -} - -func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { - var authEventList []*gomatrixserverlib.Event - var stateEventList []*gomatrixserverlib.Event - for _, state := range states { - authEventList = append(authEventList, state.AuthEvents...) - stateEventList = append(stateEventList, state.StateEvents...) - } - resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) - if err != nil { - return nil, err - } - // apply the current event -retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { - switch missing := err.(type) { - case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) - switch err2.(type) { - case verifySigError: - return &gomatrixserverlib.RespState{ - AuthEvents: authEventList, - StateEvents: resolvedStateEvents, - }, nil - case nil: - // do nothing - default: - return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) - } - util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) - resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) - goto retryAllowedState - default: - } - return nil, err - } - return &gomatrixserverlib.RespState{ - AuthEvents: authEventList, - StateEvents: resolvedStateEvents, - }, nil -} - -func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, err error) { - logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) - // query latest events (our trusted forward extremities) - req := api.QueryLatestEventsAndStateRequest{ - RoomID: e.RoomID(), - StateToFetch: needed.Tuples(), - } - var res api.QueryLatestEventsAndStateResponse - if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { - logger.WithError(err).Warn("Failed to query latest events") - return nil, err - } - latestEvents := make([]string, len(res.LatestEvents)) - for i, ev := range res.LatestEvents { - latestEvents[i] = res.LatestEvents[i].EventID - t.hadEvent(ev.EventID, true) - } - - var missingResp *gomatrixserverlib.RespMissingEvents - servers := t.getServers(ctx, e.RoomID(), e) - for _, server := range servers { - var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ - Limit: 20, - // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. - EarliestEvents: latestEvents, - // The event IDs to retrieve the previous events for. - LatestEvents: []string{e.EventID()}, - }, roomVersion); err == nil { - missingResp = &m - break - } else { - logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.Origin, server) - if errors.Is(err, context.DeadlineExceeded) { - break - } - } - } - - if missingResp == nil { - logger.WithError(err).Errorf( - "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", - t.Origin, len(servers), - ) - return nil, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } - - // security: how we handle failures depends on whether or not this event will become the new forward extremity for the room. - // There's 2 scenarios to consider: - // - Case A: We got pushed an event and are now fetching missing prev_events. (isInboundTxn=true) - // - Case B: We are fetching missing prev_events already and now fetching some more (isInboundTxn=false) - // In Case B, we know for sure that the event we are currently processing will not become the new forward extremity for the room, - // as it was called in response to an inbound txn which had it as a prev_event. - // In Case A, the event is a forward extremity, and could eventually become the _only_ forward extremity in the room. This is bad - // because it means we would trust the state at that event to be the state for the entire room, and allows rooms to be hijacked. - // https://github.com/matrix-org/synapse/pull/3456 - // https://github.com/matrix-org/synapse/blob/229eb81498b0fe1da81e9b5b333a0285acde9446/synapse/handlers/federation.py#L335 - // For now, we do not allow Case B, so reject the event. - logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) - - // Make sure events from the missingResp are using the cache - missing events - // will be added and duplicates will be removed. - for i, ev := range missingResp.Events { - missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() - } - - // topologically sort and sanity check that we are making forward progress - newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) - shouldHaveSomeEventIDs := e.PrevEventIDs() - hasPrevEvent := false -Event: - for _, pe := range shouldHaveSomeEventIDs { - for _, ev := range newEvents { - if ev.EventID() == pe { - hasPrevEvent = true - break Event - } - } - } - if !hasPrevEvent { - err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.Origin, shouldHaveSomeEventIDs) - logger.WithError(err).Errorf( - "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", - t.Origin, - ) - return nil, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } - - return newEvents, nil -} - -func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - respState *gomatrixserverlib.RespState, err error) { - state, err := t.federation.LookupState(ctx, t.Origin, roomID, eventID, roomVersion) - if err != nil { - return nil, err - } - // Check that the returned state is valid. - if err := state.Check(ctx, t.keys, nil); err != nil { - return nil, err - } - // Cache the results of this state lookup and deduplicate anything we already - // have in the cache, freeing up memory. - for i, ev := range state.AuthEvents { - state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() - } - for i, ev := range state.StateEvents { - state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() - } - return &state, nil -} - -func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, error) { - util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) - // fetch the state event IDs at the time of the event - stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID) - if err != nil { - return nil, err - } - // work out which auth/state IDs are missing - wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) - missing := make(map[string]bool) - var missingEventList []string - t.haveEventsMutex.Lock() - for _, sid := range wantIDs { - if _, ok := t.haveEvents[sid]; !ok { - if !missing[sid] { - missing[sid] = true - missingEventList = append(missingEventList, sid) - } - } - } - t.haveEventsMutex.Unlock() - - // fetch as many as we can from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } - var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil, err - } - for i, ev := range queryRes.Events { - queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) - t.hadEvent(ev.EventID(), true) - evID := queryRes.Events[i].EventID() - if missing[evID] { - delete(missing, evID) - } - } - queryRes.Events = nil // allow it to be GCed - - concurrentRequests := 8 - missingCount := len(missing) - util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) - - // If over 50% of the auth/state events from /state_ids are missing - // then we'll just call /state instead, otherwise we'll just end up - // hammering the remote side with /event requests unnecessarily. - if missingCount > concurrentRequests && missingCount > len(wantIDs)/2 { - util.GetLogger(ctx).WithFields(logrus.Fields{ - "missing": missingCount, - "event_id": eventID, - "room_id": roomID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), - }).Info("Fetching all state at event") - return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) - } - - if missingCount > 0 { - util.GetLogger(ctx).WithFields(logrus.Fields{ - "missing": missingCount, - "event_id": eventID, - "room_id": roomID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), - "concurrent_requests": concurrentRequests, - }).Info("Fetching missing state at event") - - // Create a queue containing all of the missing event IDs that we want - // to retrieve. - pending := make(chan string, missingCount) - for missingEventID := range missing { - pending <- missingEventID - } - close(pending) - - // Define how many workers we should start to do this. - if missingCount < concurrentRequests { - concurrentRequests = missingCount - } - - // Create the wait group. - var fetchgroup sync.WaitGroup - fetchgroup.Add(concurrentRequests) - - // This is the only place where we'll write to t.haveEvents from - // multiple goroutines, and everywhere else is blocked on this - // synchronous function anyway. - var haveEventsMutex sync.Mutex - - // Define what we'll do in order to fetch the missing event ID. - fetch := func(missingEventID string) { - var h *gomatrixserverlib.HeaderedEvent - h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) - switch err.(type) { - case verifySigError: - return - case nil: - break - default: - util.GetLogger(ctx).WithFields(logrus.Fields{ - "event_id": missingEventID, - "room_id": roomID, - }).Info("Failed to fetch missing event") - return - } - haveEventsMutex.Lock() - t.cacheAndReturn(h) - haveEventsMutex.Unlock() - } - - // Create the worker. - worker := func(ch <-chan string) { - defer fetchgroup.Done() - for missingEventID := range ch { - fetch(missingEventID) - } - } - - // Start the workers. - for i := 0; i < concurrentRequests; i++ { - go worker(pending) - } - - // Wait for the workers to finish. - fetchgroup.Wait() - } - - resp, err := t.createRespStateFromStateIDs(stateIDs) - return resp, err -} - -func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( - *gomatrixserverlib.RespState, error) { // nolint:unparam - t.haveEventsMutex.Lock() - defer t.haveEventsMutex.Unlock() - - // create a RespState response using the response to /state_ids as a guide - respState := gomatrixserverlib.RespState{} - - for i := range stateIDs.StateEventIDs { - ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] - if !ok { - logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) - continue - } - respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) - } - for i := range stateIDs.AuthEventIDs { - ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] - if !ok { - logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) - continue - } - respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) - } - // We purposefully do not do auth checks on the returned events, as they will still - // be processed in the exact same way, just as a 'rejected' event - // TODO: Add a field to HeaderedEvent to indicate if the event is rejected. - return &respState, nil -} - -func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { - if localFirst { - // fetch from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: []string{missingEventID}, - } - var queryRes api.QueryEventsByIDResponse - if err := t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) - } else if len(queryRes.Events) == 1 { - return queryRes.Events[0], nil - } - } - var event *gomatrixserverlib.Event - found := false - servers := t.getServers(ctx, roomID, nil) - for _, serverName := range servers { - txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) - if err != nil || len(txn.PDUs) == 0 { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") - if errors.Is(err, context.DeadlineExceeded) { - break - } - continue - } - event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) - if err != nil { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") - continue - } - found = true - break - } - if !found { - util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers)) - return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers)) - } - if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) - return nil, verifySigError{event.EventID(), err} - } - return t.cacheAndReturn(event.Headered(roomVersion)), nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 70288461..f1f6169d 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "reflect" "testing" "time" @@ -244,8 +243,6 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat eduAPI: &testEDUProducer{}, keys: &test.NopJSONVerifier{}, federation: fedClient, - haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), - hadEvents: make(map[string]bool), roomsMu: internal.NewMutexByRoom(), } t.PDUs = pdus @@ -279,6 +276,7 @@ NextPDU: } } +/* func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { NextTuple: for _, t := range tuples { @@ -294,6 +292,7 @@ NextTuple: } return } +*/ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { for _, g := range got { @@ -355,6 +354,7 @@ func TestTransactionFailAuthChecks(t *testing.T) { // we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, // resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in // topological order and sent to the roomserver. +/* func TestTransactionFetchMissingPrevEvents(t *testing.T) { haveEvent := testEvents[len(testEvents)-3] prevEvent := testEvents[len(testEvents)-2] @@ -619,3 +619,4 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) } +*/ diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index fb919a59..b16c68d2 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -89,7 +89,7 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -178,6 +178,7 @@ func ExchangeThirdPartyInvite( []*gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), }, + request.Origin(), cfg.Matrix.ServerName, nil, false, diff --git a/go.mod b/go.mod index a0af0c4e..f8f3e8a1 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220124102425-f3e2ef8d8e59 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220125141909-d6fd2b28b8e8 github.com/matrix-org/pinecone v0.0.0-20211216094739-095c5ea64d02 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.8 diff --git a/go.sum b/go.sum index 3af2715a..f87a83c4 100644 --- a/go.sum +++ b/go.sum @@ -990,8 +990,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220124102425-f3e2ef8d8e59 h1:KtXMLsXeSRx/pPq0+HTDmM+J+WTxwzt+3O17xA3u0WY= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220124102425-f3e2ef8d8e59/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220125141909-d6fd2b28b8e8 h1:v57j5jbSBgY27COjgqAtYPVX2uxxPJP/2hI3uOPCz6M= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220125141909-d6fd2b28b8e8/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20211216094739-095c5ea64d02 h1:tLn95Nqq3KPOZAjogGZTKMEkn4mMIzKu09biRTz/Ack= github.com/matrix-org/pinecone v0.0.0-20211216094739-095c5ea64d02/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/roomserver/api/input.go b/roomserver/api/input.go index a537e64e..4b0704b9 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -54,12 +54,8 @@ type InputRoomEvent struct { Kind Kind `json:"kind"` // The event JSON for the event to add. Event *gomatrixserverlib.HeaderedEvent `json:"event"` - // List of state event IDs that authenticate this event. - // These are likely derived from the "auth_events" JSON key of the event. - // But can be different because the "auth_events" key can be incomplete or wrong. - // For example many matrix events forget to reference the m.room.create event even though it is needed for auth. - // (since synapse allows this to happen we have to allow it as well.) - AuthEventIDs []string `json:"auth_event_ids"` + // Which server told us about this event. + Origin gomatrixserverlib.ServerName `json:"origin"` // Whether the state is supplied as a list of event IDs or whether it // should be derived from the state at the previous events. HasState bool `json:"has_state"` diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index cdb186c0..e9b94e48 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,6 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI RoomserverInternalAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, + origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, async bool, ) error { @@ -34,7 +35,7 @@ func SendEvents( ires[i] = InputRoomEvent{ Kind: kind, Event: event, - AuthEventIDs: event.AuthEventIDs(), + Origin: origin, SendAsServer: string(sendAsServer), TransactionID: txnID, } @@ -48,7 +49,7 @@ func SendEvents( func SendEventWithState( ctx context.Context, rsAPI RoomserverInternalAPI, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, - haveEventIDs map[string]bool, async bool, + origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { outliers, err := state.Events() if err != nil { @@ -61,9 +62,9 @@ func SendEventWithState( continue } ires = append(ires, InputRoomEvent{ - Kind: KindOutlier, - Event: outlier.Headered(event.RoomVersion), - AuthEventIDs: outlier.AuthEventIDs(), + Kind: KindOutlier, + Event: outlier.Headered(event.RoomVersion), + Origin: origin, }) } @@ -75,7 +76,7 @@ func SendEventWithState( ires = append(ires, InputRoomEvent{ Kind: kind, Event: event, - AuthEventIDs: event.AuthEventIDs(), + Origin: origin, HasState: true, StateEventIDs: stateEventIDs, }) diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index cf2e59c6..5b87e623 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -37,8 +37,11 @@ type RoomserverInternalAPI struct { Cache caching.RoomServerCaches ServerName gomatrixserverlib.ServerName KeyRing gomatrixserverlib.JSONVerifier + ServerACLs *acls.ServerACLs fsAPI fsAPI.FederationInternalAPI asAPI asAPI.AppServiceQueryAPI + JetStream nats.JetStreamContext + Durable nats.SubOpt InputRoomEventTopic string // JetStream topic for new input room events OutputRoomEventTopic string // JetStream topic for new output room events PerspectiveServerNames []gomatrixserverlib.ServerName @@ -56,21 +59,17 @@ func NewRoomserverAPI( Cache: caches, ServerName: cfg.Matrix.ServerName, PerspectiveServerNames: perspectiveServerNames, + InputRoomEventTopic: inputRoomEventTopic, + OutputRoomEventTopic: outputRoomEventTopic, + JetStream: consumer, + Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"), + ServerACLs: serverACLs, Queryer: &query.Queryer{ DB: roomserverDB, Cache: caches, ServerName: cfg.Matrix.ServerName, ServerACLs: serverACLs, }, - Inputer: &input.Inputer{ - DB: roomserverDB, - InputRoomEventTopic: inputRoomEventTopic, - OutputRoomEventTopic: outputRoomEventTopic, - JetStream: consumer, - Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"), - ServerName: cfg.Matrix.ServerName, - ACLs: serverACLs, - }, // perform-er structs get initialised when we have a federation sender to use } return a @@ -83,6 +82,18 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA r.fsAPI = fsAPI r.KeyRing = keyRing + r.Inputer = &input.Inputer{ + DB: r.DB, + InputRoomEventTopic: r.InputRoomEventTopic, + OutputRoomEventTopic: r.OutputRoomEventTopic, + JetStream: r.JetStream, + Durable: r.Durable, + ServerName: r.Cfg.Matrix.ServerName, + FSAPI: fsAPI, + KeyRing: keyRing, + ACLs: r.ServerACLs, + Queryer: r.Queryer, + } r.Inviter = &perform.Inviter{ DB: r.DB, Cfg: r.Cfg, diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 1f4215e7..ddda8081 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -56,7 +56,7 @@ func CheckForSoftFail( // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. - roomState := state.NewStateResolution(db, *roomInfo) + roomState := state.NewStateResolution(db, roomInfo) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) if err != nil { return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a389cc89..78a875c7 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -179,7 +179,7 @@ func GetMembershipsAtState( return events, nil } -func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { +func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { roomState := state.NewStateResolution(db, info) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) @@ -223,7 +223,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) @@ -279,7 +279,7 @@ func CheckServerAllowedToSeeEvent( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( - ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, + ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { var resultNIDs []types.EventNID @@ -387,7 +387,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, *roomInfo) + roomState := state.NewStateResolution(db, roomInfo) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 57e51055..9601e018 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,12 +19,15 @@ import ( "context" "encoding/json" "sync" + "time" "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/gomatrixserverlib" @@ -45,12 +48,28 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + FSAPI fedapi.FederationInternalAPI + KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs InputRoomEventTopic string OutputRoomEventTopic string workers sync.Map // room ID -> *phony.Inbox + + Queryer *query.Queryer } +func (r *Inputer) workerForRoom(roomID string) *phony.Inbox { + inbox, _ := r.workers.LoadOrStore(roomID, &phony.Inbox{}) + return inbox.(*phony.Inbox) +} + +// eventsInProgress is an in-memory map to keep a track of which events we have +// queued up for processing. If we get a redelivery from NATS and we still have +// the queued up item then we won't do anything with the redelivered message. If +// we've restarted Dendrite and now this map is empty then it means that we will +// reload pending work from NATS. +var eventsInProgress sync.Map + // onMessage is called when a new event arrives in the roomserver input stream. func (r *Inputer) Start() error { _, err := r.JetStream.Subscribe( @@ -65,11 +84,23 @@ func (r *Inputer) Start() error { _ = msg.Term() return } - inbox, _ := r.workers.LoadOrStore(roomID, &phony.Inbox{}) + + _ = msg.InProgress() + index := roomID + "\000" + inputRoomEvent.Event.EventID() + if _, ok := eventsInProgress.LoadOrStore(index, struct{}{}); ok { + // We're already waiting to deal with this event, so there's no + // point in queuing it up again. We've notified NATS that we're + // working on the message still, so that will have deferred the + // redelivery by a bit. + return + } + roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Inc() - inbox.(*phony.Inbox).Act(nil, func() { + r.workerForRoom(roomID).Act(nil, func() { + _ = msg.InProgress() // resets the acknowledgement wait timer + defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - if err := r.processRoomEvent(context.TODO(), &inputRoomEvent); err != nil { + if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { sentry.CaptureException(err) } else { hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) @@ -82,12 +113,14 @@ func (r *Inputer) Start() error { // sure that we only acknowledge when we're happy we've done everything we // can. This ensures we retry things when it makes sense to do so. nats.ManualAck(), - // NATS will try to redeliver things to us automatically if we don't ack - // or nak them within a certain amount of time. This stops that from - // happening, so we don't end up doing a lot of unnecessary duplicate work. - nats.MaxDeliver(0), // Use a durable named consumer. r.Durable, + // If we've missed things in the stream, e.g. we restarted, then replay + // all of the queued messages that were waiting for us. + nats.DeliverAll(), + // Ensure that NATS doesn't try to resend us something that wasn't done + // within the period of time that we might still be processing it. + nats.AckWait(MaximumProcessingTime+(time.Second*10)), ) return err } @@ -122,11 +155,20 @@ func (r *Inputer) InputRoomEvents( for _, e := range request.InputRoomEvents { inputRoomEvent := e roomID := inputRoomEvent.Event.RoomID() - inbox, _ := r.workers.LoadOrStore(roomID, &phony.Inbox{}) + index := roomID + "\000" + inputRoomEvent.Event.EventID() + if _, ok := eventsInProgress.LoadOrStore(index, struct{}{}); ok { + // We're already waiting to deal with this event, so there's no + // point in queuing it up again. We've notified NATS that we're + // working on the message still, so that will have deferred the + // redelivery by a bit. + return + } roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Inc() - inbox.(*phony.Inbox).Act(nil, func() { + worker := r.workerForRoom(roomID) + worker.Act(nil, func() { + defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - err := r.processRoomEvent(context.TODO(), &inputRoomEvent) + err := r.processRoomEvent(ctx, &inputRoomEvent) if err != nil { sentry.CaptureException(err) } else { @@ -142,6 +184,7 @@ func (r *Inputer) InputRoomEvents( for i := 0; i < len(request.InputRoomEvents); i++ { select { case <-ctx.Done(): + response.ErrMsg = context.DeadlineExceeded.Error() return case err := <-responses: if err != nil { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 791f7f30..5f911522 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -22,6 +22,8 @@ import ( "fmt" "time" + fedapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" @@ -37,6 +39,9 @@ func init() { prometheus.MustRegister(processRoomEventDuration) } +// TODO: Does this value make sense? +const MaximumProcessingTime = time.Minute * 2 + var processRoomEventDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "dendrite", @@ -60,9 +65,25 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // TODO: Break up function - we should probably do transaction ID checks before calling this. // nolint:gocyclo func (r *Inputer) processRoomEvent( - ctx context.Context, + inctx context.Context, input *api.InputRoomEvent, ) (err error) { + select { + case <-inctx.Done(): + // Before we do anything, make sure the context hasn't expired for this pending task. + // If it has then we'll give up straight away — it's probably a synchronous input + // request and the caller has already given up, but the inbox task was still queued. + return context.DeadlineExceeded + default: + } + + // Wrap the context with a time limit. We'll allow no more than MaximumProcessingTime for + // everything that we need to do for this event, or it's possible that we could end up wedging + // the roomserver for a very long time. + var cancel context.CancelFunc + ctx, cancel := context.WithTimeout(inctx, MaximumProcessingTime) + defer cancel() + // Measure how long it takes to process this event. started := time.Now() defer func() { @@ -75,6 +96,11 @@ func (r *Inputer) processRoomEvent( // Parse and validate the event JSON headered := input.Event event := headered.Unwrap() + logger := util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "type": event.Type(), + }) // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. @@ -87,24 +113,67 @@ func (r *Inputer) processRoomEvent( switch idFormat { case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { - util.GetLogger(ctx).WithField("event_id", event.EventID()).Infof("Already processed event; ignoring") + logger.Debugf("Already processed event; ignoring") return nil } default: - util.GetLogger(ctx).WithField("event_id", event.EventID()).Infof("Already processed event; ignoring") + logger.Debugf("Already processed event; ignoring") return nil } } } } - // Check that the event passes authentication checks and work out - // the numeric IDs for the auth events. + missingRes := &api.QueryMissingAuthPrevEventsResponse{} + serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} + if event.Type() != gomatrixserverlib.MRoomCreate || !event.StateKeyEquals("") { + missingReq := &api.QueryMissingAuthPrevEventsRequest{ + RoomID: event.RoomID(), + AuthEventIDs: event.AuthEventIDs(), + PrevEventIDs: event.PrevEventIDs(), + } + if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { + return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) + } + } + if len(missingRes.MissingAuthEventIDs) > 0 || len(missingRes.MissingPrevEventIDs) > 0 { + serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: event.RoomID(), + ExcludeSelf: true, + } + if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + } + } + if input.Origin != "" { + serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + } + + // First of all, check that the auth events of the event are known. + // If they aren't then we will ask the federation API for them. isRejected := false - authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) - if rejectionErr != nil { - logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("helpers.CheckAuthEvents failed for event, rejecting event") + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownEvents := map[string]*types.Event{} + if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.checkForMissingAuthEvents: %w", err) + } + + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + var rejectionErr error + if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { isRejected = true + logger.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID()) + } + + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + if _, ok := knownEvents[authEventID]; !ok { + return fmt.Errorf("missing auth event %s", authEventID) + } + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } var softfail bool @@ -113,11 +182,50 @@ func (r *Inputer) processRoomEvent( // current room state. softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), - }).WithError(err).Info("Error authing soft-failed event") + logger.WithError(err).Info("Error authing soft-failed event") + } + } + + // At this point we are checking whether we know all of the prev events, and + // if we know the state before the prev events. This is necessary before we + // try to do `calculateAndSetState` on the event later, otherwise it will fail + // with missing event NIDs. If there's anything missing then we'll go and fetch + // the prev events and state from the federation. Note that we only do this if + // we weren't already told what the state before the event should be — if the + // HasState option was set and a state set was provided (as is the case in a + // typical federated room join) then we won't bother trying to fetch prev events + // because we may not be allowed to see them and we have no choice but to trust + // the state event IDs provided to us in the join instead. + missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 + if missingPrev && input.Kind == api.KindNew { + // Don't do this for KindOld events, otherwise old events that we fetch + // to satisfy missing prev events/state will end up recursively calling + // processRoomEvent. + if len(serverRes.ServerNames) > 0 { + missingState := missingStateReq{ + origin: input.Origin, + inputer: r, + queryer: r.Queryer, + db: r.DB, + federation: r.FSAPI, + keys: r.KeyRing, + roomsMu: internal.NewMutexByRoom(), + servers: map[gomatrixserverlib.ServerName]struct{}{}, + hadEvents: map[string]bool{}, + haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, + } + for _, serverName := range serverRes.ServerNames { + missingState.servers[serverName] = struct{}{} + } + if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + isRejected = true + rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) + } else { + missingPrev = false + } + } else { + isRejected = true + rejectionErr = fmt.Errorf("missing prev events and no other servers to ask") } } @@ -140,12 +248,7 @@ func (r *Inputer) processRoomEvent( // doesn't have any associated state to store and we don't need to // notify anyone about it. if input.Kind == api.KindOutlier { - logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), - "sender": event.Sender(), - }).Debug("Stored outlier") + logger.Debug("Stored outlier") return nil } @@ -157,24 +260,18 @@ func (r *Inputer) processRoomEvent( return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) } - if stateAtEvent.BeforeStateSnapshotNID == 0 { + if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event, isRejected) - if err != nil && input.Kind != api.KindOld { + err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) + if err != nil { return fmt.Errorf("r.calculateAndSetState: %w", err) } } // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. if isRejected || softfail { - logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), - "soft_fail": softfail, - "sender": event.Sender(), - }).Debug("Stored rejected event") + logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event") return rejectionErr } @@ -228,10 +325,127 @@ func (r *Inputer) processRoomEvent( return nil } +// fetchAuthEvents will check to see if any of the +// auth events specified by the given event are unknown. If they are +// then we will go off and request them from the federation and then +// store them in the database. By the time this function ends, either +// we've failed to retrieve the auth chain altogether (in which case +// an error is returned) or we've successfully retrieved them all and +// they are now in the database. +func (r *Inputer) fetchAuthEvents( + ctx context.Context, + logger *logrus.Entry, + event *gomatrixserverlib.HeaderedEvent, + auth *gomatrixserverlib.AuthEvents, + known map[string]*types.Event, + servers []gomatrixserverlib.ServerName, +) error { + unknown := map[string]struct{}{} + authEventIDs := event.AuthEventIDs() + if len(authEventIDs) == 0 { + return nil + } + + for _, authEventID := range authEventIDs { + authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { + unknown[authEventID] = struct{}{} + continue + } + ev := authEvents[0] + known[authEventID] = &ev // don't take the pointer of the iterated event + if err = auth.AddEvent(ev.Event); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) + } + } + + // If there are no missing auth events then there is nothing more + // to do — we've loaded everything that we need. + if len(unknown) == 0 { + return nil + } + + var err error + var res gomatrixserverlib.RespEventAuth + var found bool + for _, serverName := range servers { + // Request the entire auth chain for the event in question. This should + // contain all of the auth events — including ones that we already know — + // so we'll need to filter through those in the next section. + res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) + if err != nil { + logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) + continue + } + found = true + break + } + if !found { + return fmt.Errorf("no servers provided event auth for event ID %q, tried servers %v", event.EventID(), servers) + } + + for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( + res.AuthEvents, + gomatrixserverlib.TopologicalOrderByAuthEvents, + ) { + // If we already know about this event from the database then we don't + // need to store it again or do anything further with it, so just skip + // over it rather than wasting cycles. + if ev, ok := known[authEvent.EventID()]; ok && ev != nil { + continue + } + + // Check the signatures of the event. + // TODO: It really makes sense for the federation API to be doing this, + // because then it can attempt another server if one serves up an event + // with an invalid signature. For now this will do. + if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { + return fmt.Errorf("event.VerifyEventSignatures: %w", err) + } + + // In order to store the new auth event, we need to know its auth chain + // as NIDs for the `auth_event_nids` column. Let's see if we can find those. + authEventNIDs := make([]types.EventNID, 0, len(authEvent.AuthEventIDs())) + for _, eventID := range authEvent.AuthEventIDs() { + knownEvent, ok := known[eventID] + if !ok { + return fmt.Errorf("missing auth event %s for %s", eventID, authEvent.EventID()) + } + authEventNIDs = append(authEventNIDs, knownEvent.EventNID) + } + + // Let's take a note of the fact that we now know about this event. + if err := auth.AddEvent(authEvent); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) + } + + // Check if the auth event should be rejected. + isRejected := false + if err := gomatrixserverlib.Allowed(authEvent, auth); err != nil { + isRejected = true + logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) + } + + // Finally, store the event in the database. + eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + if err != nil { + return fmt.Errorf("r.DB.StoreEvent: %w", err) + } + + // Now we know about this event, it was stored and the signatures were OK. + known[authEvent.EventID()] = &types.Event{ + EventNID: eventNID, + Event: authEvent, + } + } + + return nil +} + func (r *Inputer) calculateAndSetState( ctx context.Context, input *api.InputRoomEvent, - roomInfo types.RoomInfo, + roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, event *gomatrixserverlib.Event, isRejected bool, diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index c9264a27..6137941e 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -199,7 +199,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) + roomState := state.NewStateResolution(u.api.DB, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go new file mode 100644 index 00000000..44710962 --- /dev/null +++ b/roomserver/internal/input/input_missing.go @@ -0,0 +1,765 @@ +package input + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + fedapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/query" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type missingStateReq struct { + origin gomatrixserverlib.ServerName + db storage.Database + inputer *Inputer + queryer *query.Queryer + keys gomatrixserverlib.JSONVerifier + federation fedapi.FederationInternalAPI + roomsMu *internal.MutexByRoom + servers map[gomatrixserverlib.ServerName]struct{} + hadEvents map[string]bool + hadEventsMutex sync.Mutex + haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEventsMutex sync.Mutex +} + +// processEventWithMissingState is the entrypoint for a missingStateReq +// request, as called from processRoomEvent. +func (t *missingStateReq) processEventWithMissingState( + ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, +) error { + // We are missing the previous events for this events. + // This means that there is a gap in our view of the history of the + // room. There two ways that we can handle such a gap: + // 1) We can fill in the gap using /get_missing_events + // 2) We can leave the gap and request the state of the room at + // this event from the remote server using either /state_ids + // or /state. + // Synapse will attempt to do 1 and if that fails or if the gap is + // too large then it will attempt 2. + // Synapse will use /state_ids if possible since usually the state + // is largely unchanged and it is more efficient to fetch a list of + // event ids and then use /event to fetch the individual events. + // However not all version of synapse support /state_ids so you may + // need to fallback to /state. + logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ + "txn_event": e.EventID(), + "room_id": e.RoomID(), + "txn_prev_events": e.PrevEventIDs(), + }) + + // Attempt to fill in the gap using /get_missing_events + // This will either: + // - fill in the gap completely then process event `e` returning no backwards extremity + // - fail to fill in the gap and tell us to terminate the transaction err=not nil + // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction + newEvents, isGapFilled, err := t.getMissingEvents(ctx, e, roomVersion) + if err != nil { + return fmt.Errorf("t.getMissingEvents: %w", err) + } + if len(newEvents) == 0 { + return fmt.Errorf("expected to find missing events but didn't") + } + if isGapFilled { + logger.Infof("gap filled by /get_missing_events, injecting %d new events", len(newEvents)) + // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled + // in the gap in the DAG + for _, newEvent := range newEvents { + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + Kind: api.KindNew, + Event: newEvent.Headered(roomVersion), + Origin: t.origin, + SendAsServer: api.DoNotSendToOtherServers, + }) + if err != nil { + return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + } + } + return nil + } + + backwardsExtremity := newEvents[0] + newEvents = newEvents[1:] + + type respState struct { + // A snapshot is considered trustworthy if it came from our own roomserver. + // That's because the state will have been through state resolution once + // already in QueryStateAfterEvent. + trustworthy bool + *gomatrixserverlib.RespState + } + + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. + // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query + // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. + var states []*respState + for _, prevEventID := range backwardsExtremity.PrevEventIDs() { + // Look up what the state is after the backward extremity. This will either + // come from the roomserver, if we know all the required events, or it will + // come from a remote server via /state_ids if not. + prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) + if lerr != nil { + logger.WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + return lerr + } + // Append the state onto the collected state. We'll run this through the + // state resolution next. + states = append(states, &respState{trustworthy, prevState}) + } + + // Now that we have collected all of the state from the prev_events, we'll + // run the state through the appropriate state resolution algorithm for the + // room if needed. This does a couple of things: + // 1. Ensures that the state is deduplicated fully for each state-key tuple + // 2. Ensures that we pick the latest events from both sets, in the case that + // one of the prev_events is quite a bit older than the others + resolvedState := &gomatrixserverlib.RespState{} + switch len(states) { + case 0: + extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") + if !extremityIsCreate { + // There are no previous states and this isn't the beginning of the + // room - this is an error condition! + logger.Errorf("Failed to lookup any state after prev_events") + return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) + } + case 1: + // There's only one previous state - if it's trustworthy (came from a + // local state snapshot which will already have been through state res), + // use it as-is. There's no point in resolving it again. + if states[0].trustworthy { + resolvedState = states[0].RespState + break + } + // Otherwise, if it isn't trustworthy (came from federation), run it through + // state resolution anyway for safety, in case there are duplicates. + fallthrough + default: + respStates := make([]*gomatrixserverlib.RespState, len(states)) + for i := range states { + respStates[i] = states[i].RespState + } + // There's more than one previous state - run them all through state res + t.roomsMu.Lock(e.RoomID()) + resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) + t.roomsMu.Unlock(e.RoomID()) + if err != nil { + logger.WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) + return err + } + } + + hadEvents := map[string]bool{} + t.hadEventsMutex.Lock() + for k, v := range t.hadEvents { + hadEvents[k] = v + } + t.hadEventsMutex.Unlock() + + // Send outliers first so we can send the new backwards extremity without causing errors + outliers, err := resolvedState.Events() + if err != nil { + return err + } + var outlierRoomEvents []api.InputRoomEvent + for _, outlier := range outliers { + if hadEvents[outlier.EventID()] { + continue + } + outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ + Kind: api.KindOutlier, + Event: outlier.Headered(roomVersion), + Origin: t.origin, + }) + } + // TODO: we could do this concurrently? + for _, ire := range outlierRoomEvents { + if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) + } + } + + // Now send the backward extremity into the roomserver with the + // newly resolved state. This marks the "oldest" point in the backfill and + // sets the baseline state for any new events after this. + stateIDs := make([]string, 0, len(resolvedState.StateEvents)) + for _, event := range resolvedState.StateEvents { + stateIDs = append(stateIDs, event.EventID()) + } + + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + Kind: api.KindOld, + Event: backwardsExtremity.Headered(roomVersion), + Origin: t.origin, + HasState: true, + StateEventIDs: stateIDs, + SendAsServer: api.DoNotSendToOtherServers, + }) + if err != nil { + return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + } + + // Then send all of the newer backfilled events, of which will all be newer + // than the backward extremity, into the roomserver without state. This way + // they will automatically fast-forward based on the room state at the + // extremity in the last step. + for _, newEvent := range newEvents { + err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + Kind: api.KindOld, + Event: newEvent.Headered(roomVersion), + Origin: t.origin, + SendAsServer: api.DoNotSendToOtherServers, + }) + if err != nil { + return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + } + } + + return nil +} + +// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) +// added into the mix. +func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { + // try doing all this locally before we resort to querying federation + respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) + if respState != nil { + return respState, true, nil + } + + respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) + if err != nil { + return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) + } + + // fetch the event we're missing and add it to the pile + h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) + switch err.(type) { + case verifySigError: + return respState, false, nil + case nil: + // do nothing + default: + return nil, false, fmt.Errorf("t.lookupEvent: %w", err) + } + h = t.cacheAndReturn(h) + if h.StateKey() != nil { + addedToState := false + for i := range respState.StateEvents { + se := respState.StateEvents[i] + if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { + respState.StateEvents[i] = h.Unwrap() + addedToState = true + break + } + } + if !addedToState { + respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + } + } + + return respState, false, nil +} + +func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { + t.haveEventsMutex.Lock() + defer t.haveEventsMutex.Unlock() + if cached, exists := t.haveEvents[ev.EventID()]; exists { + return cached + } + t.haveEvents[ev.EventID()] = ev + return ev +} + +func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { + var res api.QueryStateAfterEventsResponse + err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ + RoomID: roomID, + PrevEventIDs: []string{eventID}, + }, &res) + if err != nil || !res.PrevEventsExist { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) + return nil + } + stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) + for i, ev := range res.StateEvents { + // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this + // processEvent request, which is better for memory. + stateEvents[i] = t.cacheAndReturn(ev) + t.hadEvent(ev.EventID()) + } + // we should never access res.StateEvents again so we delete it here to make GC faster + res.StateEvents = nil + + var authEvents []*gomatrixserverlib.Event + missingAuthEvents := map[string]bool{} + for _, ev := range stateEvents { + t.haveEventsMutex.Lock() + for _, ae := range ev.AuthEventIDs() { + if aev, ok := t.haveEvents[ae]; ok { + authEvents = append(authEvents, aev.Unwrap()) + } else { + missingAuthEvents[ae] = true + } + } + t.haveEventsMutex.Unlock() + } + // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't + // have stored the event. + if len(missingAuthEvents) > 0 { + var missingEventList []string + for evID := range missingAuthEvents { + missingEventList = append(missingEventList, evID) + } + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + util.GetLogger(ctx).WithField("count", len(missingEventList)).Infof("Fetching missing auth events") + var queryRes api.QueryEventsByIDResponse + if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + return nil + } + for i, ev := range queryRes.Events { + authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + t.hadEvent(ev.EventID()) + } + queryRes.Events = nil + } + + return &gomatrixserverlib.RespState{ + StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), + AuthEvents: authEvents, + } +} + +// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what +// the server supports. +func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( + *gomatrixserverlib.RespState, error) { + + // Attempt to fetch the missing state using /state_ids and /events + return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) +} + +func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { + var authEventList []*gomatrixserverlib.Event + var stateEventList []*gomatrixserverlib.Event + for _, state := range states { + authEventList = append(authEventList, state.AuthEvents...) + stateEventList = append(stateEventList, state.StateEvents...) + } + resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + if err != nil { + return nil, err + } + // apply the current event +retryAllowedState: + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + switch missing := err.(type) { + case gomatrixserverlib.MissingAuthEventError: + h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) + switch err2.(type) { + case verifySigError: + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil + case nil: + // do nothing + default: + return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) + } + util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) + resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) + goto retryAllowedState + default: + } + return nil, err + } + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil +} + +// get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, +// without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events +func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled bool, err error) { + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) + // query latest events (our trusted forward extremities) + req := api.QueryLatestEventsAndStateRequest{ + RoomID: e.RoomID(), + StateToFetch: needed.Tuples(), + } + var res api.QueryLatestEventsAndStateResponse + if err = t.queryer.QueryLatestEventsAndState(ctx, &req, &res); err != nil { + logger.WithError(err).Warn("Failed to query latest events") + return nil, false, err + } + latestEvents := make([]string, len(res.LatestEvents)) + for i, ev := range res.LatestEvents { + latestEvents[i] = res.LatestEvents[i].EventID + t.hadEvent(ev.EventID) + } + + var missingResp *gomatrixserverlib.RespMissingEvents + for server := range t.servers { + var m gomatrixserverlib.RespMissingEvents + if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + Limit: 20, + // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. + EarliestEvents: latestEvents, + // The event IDs to retrieve the previous events for. + LatestEvents: []string{e.EventID()}, + }, roomVersion); err == nil { + missingResp = &m + break + } else { + logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) + if errors.Is(err, context.DeadlineExceeded) { + select { + case <-ctx.Done(): // the parent request context timed out + return nil, false, context.DeadlineExceeded + default: // this request exceed its own timeout + continue + } + } + } + } + + if missingResp == nil { + logger.WithError(err).Errorf( + "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.origin, len(t.servers), + ) + return nil, false, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + + // Make sure events from the missingResp are using the cache - missing events + // will be added and duplicates will be removed. + logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) + for i, ev := range missingResp.Events { + missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + + // topologically sort and sanity check that we are making forward progress + newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) + shouldHaveSomeEventIDs := e.PrevEventIDs() + hasPrevEvent := false +Event: + for _, pe := range shouldHaveSomeEventIDs { + for _, ev := range newEvents { + if ev.EventID() == pe { + hasPrevEvent = true + break Event + } + } + } + if !hasPrevEvent { + err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.origin, shouldHaveSomeEventIDs) + logger.WithError(err).Errorf( + "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.origin, + ) + return nil, false, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + if len(newEvents) == 0 { + return nil, false, nil // TODO: error instead? + } + + // now check if we can fill the gap. Look to see if we have state snapshot IDs for the earliest event + earliestNewEvent := newEvents[0] + if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 { + if earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { + // we got to the beginning of the room so there will be no state! It's all good we can process this + return newEvents, true, nil + } + // we don't have the state at this earliest event from /g_m_e so we won't have state for later events either + return newEvents, false, nil + } + // StateAtEventIDs returned some kind of state for the earliest event so we can fill in the gap! + return newEvents, true, nil +} + +func (t *missingStateReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + respState *gomatrixserverlib.RespState, err error) { + state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) + if err != nil { + return nil, err + } + // Check that the returned state is valid. + if err := state.Check(ctx, t.keys, nil); err != nil { + return nil, err + } + // Cache the results of this state lookup and deduplicate anything we already + // have in the cache, freeing up memory. + for i, ev := range state.AuthEvents { + state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + for i, ev := range state.StateEvents { + state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + } + return &state, nil +} + +func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + *gomatrixserverlib.RespState, error) { + util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) + // fetch the state event IDs at the time of the event + stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) + if err != nil { + return nil, err + } + // work out which auth/state IDs are missing + wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) + missing := make(map[string]bool) + var missingEventList []string + t.haveEventsMutex.Lock() + for _, sid := range wantIDs { + if _, ok := t.haveEvents[sid]; !ok { + if !missing[sid] { + missing[sid] = true + missingEventList = append(missingEventList, sid) + } + } + } + t.haveEventsMutex.Unlock() + + // fetch as many as we can from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + var queryRes api.QueryEventsByIDResponse + if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + return nil, err + } + for i, ev := range queryRes.Events { + queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + t.hadEvent(ev.EventID()) + evID := queryRes.Events[i].EventID() + if missing[evID] { + delete(missing, evID) + } + } + queryRes.Events = nil // allow it to be GCed + + concurrentRequests := 8 + missingCount := len(missing) + util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) + + // If over 50% of the auth/state events from /state_ids are missing + // then we'll just call /state instead, otherwise we'll just end up + // hammering the remote side with /event requests unnecessarily. + if missingCount > concurrentRequests && missingCount > len(wantIDs)/2 { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "missing": missingCount, + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + }).Info("Fetching all state at event") + return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) + } + + if missingCount > 0 { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "missing": missingCount, + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + "concurrent_requests": concurrentRequests, + }).Info("Fetching missing state at event") + + // Create a queue containing all of the missing event IDs that we want + // to retrieve. + pending := make(chan string, missingCount) + for missingEventID := range missing { + pending <- missingEventID + } + close(pending) + + // Define how many workers we should start to do this. + if missingCount < concurrentRequests { + concurrentRequests = missingCount + } + + // Create the wait group. + var fetchgroup sync.WaitGroup + fetchgroup.Add(concurrentRequests) + + // This is the only place where we'll write to t.haveEvents from + // multiple goroutines, and everywhere else is blocked on this + // synchronous function anyway. + var haveEventsMutex sync.Mutex + + // Define what we'll do in order to fetch the missing event ID. + fetch := func(missingEventID string) { + var h *gomatrixserverlib.HeaderedEvent + h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) + switch err.(type) { + case verifySigError: + return + case nil: + break + default: + util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": missingEventID, + "room_id": roomID, + }).Info("Failed to fetch missing event") + return + } + haveEventsMutex.Lock() + t.cacheAndReturn(h) + haveEventsMutex.Unlock() + } + + // Create the worker. + worker := func(ch <-chan string) { + defer fetchgroup.Done() + for missingEventID := range ch { + fetch(missingEventID) + } + } + + // Start the workers. + for i := 0; i < concurrentRequests; i++ { + go worker(pending) + } + + // Wait for the workers to finish. + fetchgroup.Wait() + } + + resp, err := t.createRespStateFromStateIDs(stateIDs) + return resp, err +} + +func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( + *gomatrixserverlib.RespState, error) { // nolint:unparam + t.haveEventsMutex.Lock() + defer t.haveEventsMutex.Unlock() + + // create a RespState response using the response to /state_ids as a guide + respState := gomatrixserverlib.RespState{} + + for i := range stateIDs.StateEventIDs { + ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] + if !ok { + logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) + continue + } + respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) + } + for i := range stateIDs.AuthEventIDs { + ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] + if !ok { + logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) + continue + } + respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) + } + // We purposefully do not do auth checks on the returned events, as they will still + // be processed in the exact same way, just as a 'rejected' event + // TODO: Add a field to HeaderedEvent to indicate if the event is rejected. + return &respState, nil +} + +func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { + if localFirst { + // fetch from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: []string{missingEventID}, + } + var queryRes api.QueryEventsByIDResponse + if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) + } else if len(queryRes.Events) == 1 { + return queryRes.Events[0], nil + } + } + var event *gomatrixserverlib.Event + found := false + for serverName := range t.servers { + reqctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) + if err != nil || len(txn.PDUs) == 0 { + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") + if errors.Is(err, context.DeadlineExceeded) { + select { + case <-reqctx.Done(): // this server took too long + continue + case <-ctx.Done(): // the input request timed out + return nil, context.DeadlineExceeded + } + } + continue + } + event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") + continue + } + found = true + break + } + if !found { + util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) + return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) + } + if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + return nil, verifySigError{event.EventID(), err} + } + return t.cacheAndReturn(event.Headered(roomVersion)), nil +} + +func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { + authUsingState := gomatrixserverlib.NewAuthEvents(nil) + for i := range stateEvents { + err := authUsingState.AddEvent(stateEvents[i]) + if err != nil { + return err + } + } + return gomatrixserverlib.Allowed(e, &authUsingState) +} + +func (t *missingStateReq) hadEvent(eventID string) { + t.hadEventsMutex.Lock() + defer t.hadEventsMutex.Unlock() + t.hadEvents[eventID] = true +} + +type verifySigError struct { + eventID string + err error +} +type missingPrevEventsError struct { + eventID string + err error +} + +func (e verifySigError) Error() string { + return fmt.Sprintf("unable to verify signature of event %q: %s", e.eventID, e.err) +} +func (e missingPrevEventsError) Error() string { + return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) +} diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index e198f67d..f3623de8 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -77,7 +77,7 @@ func (r *Backfiller) PerformBackfill( } // Scan the event tree for events to send back. - resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -418,7 +418,7 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID]) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 98f5f6f9..d19fc838 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek( response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion) // XXX: do we actually need to do a state resolution here? - roomState := state.NewStateResolution(r.DB, *info) + roomState := state.NewStateResolution(r.DB, info) var stateEntries []types.StateEntry stateEntries, err = roomState.LoadStateAtSnapshot( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index ca065468..85b2322f 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -172,7 +172,7 @@ func (r *Inviter) PerformInvite( { Kind: api.KindNew, Event: event, - AuthEventIDs: event.AuthEventIDs(), + Origin: event.Origin(), SendAsServer: req.SendAsServer, }, }, @@ -231,7 +231,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db, *info) + roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 75397eb6..a1ffab5d 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -271,7 +271,6 @@ func (r *Joiner) performJoinRoomByID( { Kind: rsAPI.KindNew, Event: event.Headered(buildRes.RoomVersion), - AuthEventIDs: event.AuthEventIDs(), SendAsServer: string(r.Cfg.Matrix.ServerName), }, }, diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 4daeb10a..eac528ea 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -139,7 +139,7 @@ func (r *Leaver) performLeaveRoomByID( { Kind: api.KindNew, Event: event.Headered(buildRes.RoomVersion), - AuthEventIDs: event.AuthEventIDs(), + Origin: event.Origin(), SendAsServer: string(r.Cfg.Matrix.ServerName), }, }, diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 28b140c7..6b4cb581 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -63,7 +63,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, *info) + roomState := state.NewStateResolution(r.DB, info) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -294,7 +294,7 @@ func (r *Queryer) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -377,7 +377,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) } response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom, + ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, ) return } @@ -416,7 +416,7 @@ func (r *Queryer) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) } - resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -473,7 +473,7 @@ func (r *Queryer) QueryStateAndAuthChain( } var stateEvents []*gomatrixserverlib.Event - stateEvents, err = r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) + stateEvents, err = r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) if err != nil { return err } @@ -512,7 +512,7 @@ func (r *Queryer) QueryStateAndAuthChain( return err } -func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 78398fc7..15d592b4 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -32,11 +32,11 @@ import ( type StateResolution struct { db storage.Database - roomInfo types.RoomInfo + roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { +func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c549fb65..778cd8d7 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -311,7 +311,9 @@ func (s *eventStatements) BulkSelectStateAtEventByID( ); err != nil { return nil, err } - if result.BeforeStateSnapshotNID == 0 { + // Genuine create events are the only case where it's OK to have no previous state. + isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1 + if result.BeforeStateSnapshotNID == 0 && !isCreate { return nil, types.MissingEventError( fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), ) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 3127eb17..7483e281 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -322,7 +322,9 @@ func (s *eventStatements) BulkSelectStateAtEventByID( ); err != nil { return nil, err } - if result.BeforeStateSnapshotNID == 0 { + // Genuine create events are the only case where it's OK to have no previous state. + isCreate := result.EventTypeNID == types.MRoomCreateNID && result.EventStateKeyNID == 1 + if result.BeforeStateSnapshotNID == 0 && !isCreate { return nil, types.MissingEventError( fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), ) diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 2d563226..1891b96b 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -3,6 +3,7 @@ package jetstream import "github.com/nats-io/nats.go" func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) { + _ = msg.InProgress() if f(msg) { _ = msg.Ack() } else { diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index 0fd31082..5810a2a9 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -24,7 +24,7 @@ var ( var streams = []*nats.StreamConfig{ { Name: InputRoomEvent, - Retention: nats.InterestPolicy, + Retention: nats.WorkQueuePolicy, Storage: nats.FileStorage, }, { diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index e048d736..8a35e414 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -643,9 +643,8 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event var ires []roomserver.InputRoomEvent for _, outlier := range append(eventsInOrder, messageEvents...) { ires = append(ires, roomserver.InputRoomEvent{ - Kind: roomserver.KindOutlier, - Event: outlier.Headered(outlier.Version()), - AuthEventIDs: outlier.AuthEventIDs(), + Kind: roomserver.KindOutlier, + Event: outlier.Headered(outlier.Version()), }) } // we've got the data by this point so use a background context diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 6b3ebe53..e9c4abe8 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -73,7 +73,11 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) + _, err := s.jetstream.Subscribe( + s.topic, s.onMessage, s.durable, + nats.DeliverAll(), + nats.ManualAck(), + ) return err } diff --git a/sytest-blacklist b/sytest-blacklist index 5e562845..3e08f0cb 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -25,7 +25,6 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes Device list doesn't change if remote server is down -If a device list update goes missing, the server resyncs on the next one # we don't support groups Remove group category @@ -33,4 +32,3 @@ Remove group role # See https://github.com/matrix-org/sytest/pull/1142 Device list doesn't change if remote server is down -If a device list update goes missing, the server resyncs on the next one diff --git a/sytest-whitelist b/sytest-whitelist index f11cd96b..7d26c610 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -589,3 +589,4 @@ Remote user can backfill in a room with version 9 Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 Forward extremities remain so even after the next events are populated as outliers +If a device list update goes missing, the server resyncs on the next one