From 617131030707aacd39f0f771626eaa5b8f88299c Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 27 Apr 2023 16:35:19 +0100 Subject: [PATCH] Use PDU interface (#3070) We only use it in a few places currently, enough to get things to compile and run. We should be using it in much more places. Similarly, in some places we cast []PDU back to []*Event, we need to not do that. Likewise, in some places we cast PDU to *Event, we need to not do that. For now though, hopefully this is a start. --- clientapi/routing/sendevent.go | 2 +- cmd/resolve-state/main.go | 6 ++--- federationapi/internal/perform.go | 8 +++--- federationapi/routing/backfill.go | 2 +- federationapi/routing/join.go | 2 +- federationapi/routing/leave.go | 2 +- go.mod | 2 +- go.sum | 4 +++ internal/transactionrequest.go | 2 +- roomserver/api/wrapper.go | 3 ++- roomserver/internal/helpers/auth.go | 20 ++++++++------- roomserver/internal/input/input_events.go | 12 +++++---- roomserver/internal/input/input_missing.go | 25 +++++++++++-------- roomserver/internal/perform/perform_admin.go | 4 +-- .../internal/perform/perform_backfill.go | 22 ++++++++-------- .../internal/perform/perform_inbound_peek.go | 4 +-- .../internal/perform/perform_upgrade.go | 2 +- roomserver/internal/query/query.go | 15 +++++++---- roomserver/state/state.go | 10 ++++---- roomserver/storage/interface.go | 8 +++--- roomserver/storage/shared/storage.go | 6 ++--- setup/mscs/msc2836/msc2836.go | 4 +-- syncapi/streams/stream_pdu.go | 8 +++--- 23 files changed, 96 insertions(+), 77 deletions(-) diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 71dc6c40..0d01367d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -317,7 +317,7 @@ func generateSendEvent( for i := range queryRes.StateEvents { stateEvents[i] = queryRes.StateEvents[i].Event } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) + provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) if err = gomatrixserverlib.Allowed(e.Event, &provider); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 09c0e690..b2f4afa8 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -183,8 +183,8 @@ func main() { var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( gomatrixserverlib.RoomVersion(*roomVersion), - events, - authEvents, + gomatrixserverlib.ToPDUs(events), + gomatrixserverlib.ToPDUs(authEvents), ) if err != nil { panic(err) @@ -208,7 +208,7 @@ func main() { fmt.Println("Returned", count, "state events after filtering") } -type Events []*gomatrixserverlib.Event +type Events []gomatrixserverlib.PDU func (e Events) Len() int { return len(e) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 8882b5c1..fccea866 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -634,13 +634,13 @@ func federatedEventProvider( ) gomatrixserverlib.EventProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. - retries := map[string][]*gomatrixserverlib.Event{} + retries := map[string][]gomatrixserverlib.PDU{} // Define a function which we can pass to Check to retrieve missing // auth events inline. This greatly increases our chances of not having // to repeat the entire set of checks just for a missing event or two. - return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { - returning := []*gomatrixserverlib.Event{} + return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.PDU, error) { + returning := []gomatrixserverlib.PDU{} verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) if err != nil { return nil, err @@ -680,7 +680,7 @@ func federatedEventProvider( } // Check the signatures of the event. - if err := ev.VerifyEventSignatures(ctx, keyRing); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil { return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) } diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 40cb88fb..06685387 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -114,7 +114,7 @@ func Backfill( eventJSONs := []json.RawMessage{} for _, e := range gomatrixserverlib.ReverseTopologicalOrdering( - evs, + gomatrixserverlib.ToPDUs(evs), gomatrixserverlib.TopologicalOrderByPrevEvents, ) { eventJSONs = append(eventJSONs, e.JSON()) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index eee0f3d9..a6a7511c 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -168,7 +168,7 @@ func MakeJoin( stateEvents[i] = queryRes.StateEvents[i].Event } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) + provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index ae7617fa..c9b13b98 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -114,7 +114,7 @@ func MakeLeave( for i := range queryRes.StateEvents { stateEvents[i] = queryRes.StateEvents[i].Event } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) + provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/go.mod b/go.mod index 527320d2..04cd27fe 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230427113737-4a73af377afe + github.com/matrix-org/gomatrixserverlib v0.0.0-20230427151624-793f1829f540 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 077b6a0f..f18bde80 100644 --- a/go.sum +++ b/go.sum @@ -331,6 +331,10 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20230427083830-f2324ed2e085 h1:QR github.com/matrix-org/gomatrixserverlib v0.0.0-20230427083830-f2324ed2e085/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230427113737-4a73af377afe h1:+FeGaWZCDw7w3DGZhQW3n0amZ4iW5at/ocErlOFrO58= github.com/matrix-org/gomatrixserverlib v0.0.0-20230427113737-4a73af377afe/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230427132341-80bb893dc05c h1:qr5H8hWq+qFQXefajLf843wnB5WhppWedWQchlqp6Tc= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230427132341-80bb893dc05c/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230427151624-793f1829f540 h1:T+8YYREEIKM7QFcmOFvh3hv1gPvN0l8LCI+goNfeMO0= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230427151624-793f1829f540/go.mod h1:7HTbSZe+CIdmeqVyFMekwD5dFU8khWQyngKATvd12FU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 13bb9fa4..29107b4d 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -168,7 +168,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 73feb2d3..d4606622 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -55,7 +55,8 @@ func SendEventWithState( state gomatrixserverlib.StateResponse, event *types.HeaderedEvent, origin spec.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers := gomatrixserverlib.LineariseStateResponse(event.Version(), state) + outliersPDU := gomatrixserverlib.LineariseStateResponse(event.Version(), state) + outliers := gomatrixserverlib.TempCastToEvents(outliersPDU) ires := make([]InputRoomEvent, 0, len(outliers)) for _, outlier := range outliers { if haveEventIDs[outlier.EventID()] { diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 0fdd6982..48e2e1cf 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -65,7 +65,9 @@ func CheckForSoftFail( } // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Event}) + stateNeeded := gomatrixserverlib.StateNeededForAuth( + gomatrixserverlib.ToPDUs([]*gomatrixserverlib.Event{event.Event}), + ) // Load the actual auth events from the database. authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) @@ -98,7 +100,7 @@ func CheckAuthEvents( authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Event}) + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event.Event}) // Load the actual auth events from the database. authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) @@ -132,31 +134,31 @@ func (ae *authEvents) Valid() bool { } // Create implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) { +func (ae *authEvents) Create() (gomatrixserverlib.PDU, error) { return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil } // PowerLevels implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) { +func (ae *authEvents) PowerLevels() (gomatrixserverlib.PDU, error) { return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil } // JoinRules implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) { +func (ae *authEvents) JoinRules() (gomatrixserverlib.PDU, error) { return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil } // Memmber implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) { +func (ae *authEvents) Member(stateKey string) (gomatrixserverlib.PDU, error) { return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil } // ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider -func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) { +func (ae *authEvents) ThirdPartyInvite(stateKey string) (gomatrixserverlib.PDU, error) { return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil } -func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) gomatrixserverlib.PDU { eventNID, ok := ae.state.lookup(types.StateKeyTuple{ EventTypeNID: typeNID, EventStateKeyNID: types.EmptyStateKeyNID, @@ -171,7 +173,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) * return event.Event } -func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) gomatrixserverlib.PDU { stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] if !ok { return nil diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 6fa7cfc2..763e4170 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -545,7 +545,7 @@ func (r *Inputer) processStateBefore( // will include the history visibility here even though we don't // actually need it for auth, because we want to send it in the // output events. - tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event}).Tuples() + tuplesNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.PDU{event}).Tuples() tuplesNeeded = append(tuplesNeeded, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomHistoryVisibility, StateKey: "", @@ -576,7 +576,9 @@ func (r *Inputer) processStateBefore( // At this point, stateBeforeEvent should be populated either by // the supplied state in the input request, or from the prev events. // Check whether the event is allowed or not. - stateBeforeAuth := gomatrixserverlib.NewAuthEvents(stateBeforeEvent) + stateBeforeAuth := gomatrixserverlib.NewAuthEvents( + gomatrixserverlib.ToPDUs(stateBeforeEvent), + ) if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return @@ -675,7 +677,7 @@ func (r *Inputer) fetchAuthEvents( isRejected := false nextAuthEvent: for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( - res.AuthEvents.UntrustedEvents(event.Version()), + gomatrixserverlib.ToPDUs(res.AuthEvents.UntrustedEvents(event.Version())), gomatrixserverlib.TopologicalOrderByAuthEvents, ) { // If we already know about this event from the database then we don't @@ -688,7 +690,7 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { continue nextAuthEvent } @@ -743,7 +745,7 @@ nextAuthEvent: // Now we know about this event, it was stored and the signatures were OK. known[authEvent.EventID()] = &types.Event{ EventNID: eventNID, - Event: authEvent, + Event: authEvent.(*gomatrixserverlib.Event), } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index b56b2418..6847509b 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -26,7 +26,7 @@ type parsedRespState struct { StateEvents []*gomatrixserverlib.Event } -func (p *parsedRespState) Events() []*gomatrixserverlib.Event { +func (p *parsedRespState) Events() []gomatrixserverlib.PDU { eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents)) for i, event := range p.AuthEvents { eventsByID[event.EventID()] = p.AuthEvents[i] @@ -38,7 +38,8 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event { for _, event := range eventsByID { allEvents = append(allEvents, event) } - return gomatrixserverlib.ReverseTopologicalOrdering(allEvents, gomatrixserverlib.TopologicalOrderByAuthEvents) + return gomatrixserverlib.ReverseTopologicalOrdering( + gomatrixserverlib.ToPDUs(allEvents), gomatrixserverlib.TopologicalOrderByAuthEvents) } type missingStateReq struct { @@ -155,7 +156,7 @@ func (t *missingStateReq) processEventWithMissingState( } outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ Kind: api.KindOutlier, - Event: &types.HeaderedEvent{Event: outlier}, + Event: &types.HeaderedEvent{Event: outlier.(*gomatrixserverlib.Event)}, Origin: t.origin, }) } @@ -468,7 +469,9 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion authEventList = append(authEventList, state.AuthEvents...) stateEventList = append(stateEventList, state.StateEvents...) } - resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), + ) if err != nil { return nil, err } @@ -482,7 +485,7 @@ retryAllowedState: case verifySigError: return &parsedRespState{ AuthEvents: authEventList, - StateEvents: resolvedStateEvents, + StateEvents: gomatrixserverlib.TempCastToEvents(resolvedStateEvents), }, nil case nil: // do nothing @@ -498,7 +501,7 @@ retryAllowedState: } return &parsedRespState{ AuthEvents: authEventList, - StateEvents: resolvedStateEvents, + StateEvents: gomatrixserverlib.TempCastToEvents(resolvedStateEvents), }, nil } @@ -559,7 +562,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // will be added and duplicates will be removed. missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = ev.VerifyEventSignatures(ctx, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { continue } missingEvents = append(missingEvents, t.cacheAndReturn(ev)) @@ -567,7 +570,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve logger.Debugf("get_missing_events returned %d events (%d passed signature checks)", len(missingResp.Events), len(missingEvents)) // topologically sort and sanity check that we are making forward progress - newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingEvents, gomatrixserverlib.TopologicalOrderByPrevEvents) + newEventsPDUs := gomatrixserverlib.ReverseTopologicalOrdering( + gomatrixserverlib.ToPDUs(missingEvents), gomatrixserverlib.TopologicalOrderByPrevEvents) + newEvents = gomatrixserverlib.TempCastToEvents(newEventsPDUs) shouldHaveSomeEventIDs := e.PrevEventIDs() hasPrevEvent := false Event: @@ -882,14 +887,14 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event), nil } -func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { +func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []gomatrixserverlib.PDU) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 2d96721e..e08f3d61 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -334,13 +334,13 @@ func (r *Admin) PerformAdminDownloadState( return nil } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = authEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = stateEvent.VerifyEventSignatures(ctx, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { continue } stateEventMap[stateEvent.EventID()] = stateEvent diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 6b250150..daaf5878 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -133,7 +133,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, gomatrixserverlib.TempCastToEvents(events)) for _, ev := range backfilledEventMap { // now add state for these events @@ -170,7 +170,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform res.Events = make([]*types.HeaderedEvent, len(events)) for i := range events { - res.Events[i] = &types.HeaderedEvent{Event: events[i]} + res.Events[i] = &types.HeaderedEvent{Event: events[i].(*gomatrixserverlib.Event)} } res.HistoryVisibility = requester.historyVisiblity return nil @@ -230,7 +230,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom logger.WithError(err).Warn("event failed PDU checks") continue } - missingMap[id] = &types.HeaderedEvent{Event: res.Event} + missingMap[id] = &types.HeaderedEvent{Event: res.Event.(*gomatrixserverlib.Event)} } } } @@ -257,7 +257,7 @@ type backfillRequester struct { // per-request state servers []spec.ServerName eventIDToBeforeStateIDs map[string][]string - eventIDMap map[string]*gomatrixserverlib.Event + eventIDMap map[string]gomatrixserverlib.PDU historyVisiblity gomatrixserverlib.HistoryVisibility roomInfo types.RoomInfo } @@ -278,14 +278,14 @@ func newBackfillRequester( virtualHost: virtualHost, isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), - eventIDMap: make(map[string]*gomatrixserverlib.Event), + eventIDMap: make(map[string]gomatrixserverlib.PDU), bwExtrems: bwExtrems, preferServer: preferServer, historyVisiblity: gomatrixserverlib.HistoryVisibilityShared, } } -func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent *gomatrixserverlib.Event) ([]string, error) { +func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.PDU) ([]string, error) { b.eventIDMap[targetEvent.EventID()] = targetEvent if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { return ids, nil @@ -337,7 +337,7 @@ FederationHit: return nil, lastErr } -func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatrixserverlib.Event, prevEventStateIDs []string) []string { +func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.PDU, prevEventStateIDs []string) []string { newStateIDs := prevEventStateIDs[:] if prevEvent.StateKey() == nil { // state is the same as the previous event @@ -375,7 +375,7 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatri } func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - event *gomatrixserverlib.Event, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { + event gomatrixserverlib.PDU, eventIDs []string) (map[string]gomatrixserverlib.PDU, error) { // try to fetch the events from the database first events, err := b.ProvideEvents(roomVer, eventIDs) @@ -385,7 +385,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr } else { logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs)) if len(events) == len(eventIDs) { - result := make(map[string]*gomatrixserverlib.Event) + result := make(map[string]gomatrixserverlib.PDU) for i := range events { result[events[i].EventID()] = events[i] b.eventIDMap[events[i].EventID()] = events[i] @@ -516,7 +516,7 @@ func (b *backfillRequester) Backfill(ctx context.Context, origin, server spec.Se return tx, err } -func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { +func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.PDU, error) { ctx := context.Background() nidMap, err := b.db.EventNIDs(ctx, eventIDs) if err != nil { @@ -538,7 +538,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err } - events := make([]*gomatrixserverlib.Event, len(eventsWithNids)) + events := make([]gomatrixserverlib.PDU, len(eventsWithNids)) for i := range eventsWithNids { events[i] = eventsWithNids[i].Event } diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 19b81c64..68b82746 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -68,7 +68,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - var sortedLatestEvents []*gomatrixserverlib.Event + var sortedLatestEvents []gomatrixserverlib.PDU for _, ev := range latestEvents { sortedLatestEvents = append(sortedLatestEvents, ev.Event) } @@ -76,7 +76,7 @@ func (r *InboundPeeker) PerformInboundPeek( sortedLatestEvents, gomatrixserverlib.TopologicalOrderByPrevEvents, ) - response.LatestEvent = &types.HeaderedEvent{Event: sortedLatestEvents[0]} + response.LatestEvent = &types.HeaderedEvent{Event: sortedLatestEvents[0].(*gomatrixserverlib.Event)} // XXX: do we actually need to do a state resolution here? roomState := state.NewStateResolution(r.DB, info) diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index e37f0e21..bfe70354 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -642,7 +642,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user for i := range queryRes.StateEvents { stateEvents[i] = queryRes.StateEvents[i].Event } - provider := gomatrixserverlib.NewAuthEvents(stateEvents) + provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) if err = gomatrixserverlib.Allowed(headeredEvent.Event, &provider); err != nil { return nil, &api.PerformError{ Code: api.PerformErrorNotAllowed, diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 052ce0a8..4bd648a9 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -121,10 +121,13 @@ func (r *Queryer) QueryStateAfterEvents( return fmt.Errorf("getAuthChain: %w", err) } - stateEvents, err = gomatrixserverlib.ResolveConflicts(info.RoomVersion, stateEvents, authEvents) + stateEventsPDU, err := gomatrixserverlib.ResolveConflicts( + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) } + stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU) } for _, event := range stateEvents { @@ -585,11 +588,13 @@ func (r *Queryer) QueryStateAndAuthChain( } if request.ResolveState { - if stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, stateEvents, authEvents, - ); err != nil { - return err + stateEventsPDU, err2 := gomatrixserverlib.ResolveConflicts( + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + ) + if err2 != nil { + return err2 } + stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU) } for _, event := range stateEvents { diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d20877b4..d04b8f6c 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -996,7 +996,7 @@ func (v *StateResolution) resolveConflictsV2( // For each conflicted event, we will add a new set of auth events. Auth // events may be duplicated across these sets but that's OK. authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) - authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) + authEvents := make([]gomatrixserverlib.PDU, 0, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3) knownAuthEvents := make(map[string]types.Event, estimate*3) @@ -1046,7 +1046,7 @@ func (v *StateResolution) resolveConflictsV2( gotAuthEvents = nil // nolint:ineffassign // Resolve the conflicts. - resolvedEvents := func() []*gomatrixserverlib.Event { + resolvedEvents := func() []gomatrixserverlib.PDU { resolvedTrace, _ := internal.StartRegion(ctx, "StateResolution.ResolveStateConflictsV2") defer resolvedTrace.EndRegion() @@ -1119,11 +1119,11 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E // Returns an error if there was a problem talking to the database. func (v *StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, -) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { +) ([]gomatrixserverlib.PDU, map[string]types.StateEntry, error) { trace, ctx := internal.StartRegion(ctx, "StateResolution.loadStateEvents") defer trace.EndRegion() - result := make([]*gomatrixserverlib.Event, 0, len(entries)) + result := make([]gomatrixserverlib.PDU, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries)) eventNIDs := make(types.EventNIDs, 0, len(entries)) for _, entry := range entries { @@ -1163,7 +1163,7 @@ type authEventLoader struct { // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. func (l *authEventLoader) loadAuthEvents( - ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event, + ctx context.Context, roomInfo *types.RoomInfo, event gomatrixserverlib.PDU, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { l.Lock() defer l.Unlock() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 3915f4bb..1cf05d59 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -77,7 +77,7 @@ type Database interface { SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error. - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) + StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. @@ -182,7 +182,7 @@ type Database interface { GetMembershipForHistoryVisibility( ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ) (map[string]*types.HeaderedEvent, error) - GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) + GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) MaybeRedactEvent( @@ -207,7 +207,7 @@ type RoomDatabase interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) - GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) + GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) @@ -230,5 +230,5 @@ type EventDatabase interface { MaybeRedactEvent( ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, plResolver state.PowerLevelResolver, ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) + StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6dc9280c..c31302cf 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -658,7 +658,7 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e } // GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. -func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) { +func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version // then we will use our configured default to create the room. // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom @@ -725,7 +725,7 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe } func (d *EventDatabase) StoreEvent( - ctx context.Context, event *gomatrixserverlib.Event, + ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool, ) (types.EventNID, types.StateAtEvent, error) { @@ -909,7 +909,7 @@ func (d *EventDatabase) assignStateKeyNID( return eventStateKeyNID, err } -func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( +func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) ( gomatrixserverlib.RoomVersion, error, ) { var err error diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index b106a246..38412fa2 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -646,7 +646,7 @@ func (rc *reqCtx) getLocalEvent(roomID, eventID string) *types.HeaderedEvent { // into the roomserver as KindOutlier, with auth chains. func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsResponse) { var stateEvents gomatrixserverlib.EventJSONs - var messageEvents []*gomatrixserverlib.Event + var messageEvents []gomatrixserverlib.PDU for _, ev := range res.ParsedEvents { if ev.StateKey() != nil { stateEvents = append(stateEvents, ev.JSON()) @@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo for _, outlier := range append(eventsInOrder, messageEvents...) { ires = append(ires, roomserver.InputRoomEvent{ Kind: roomserver.KindOutlier, - Event: &types.HeaderedEvent{Event: outlier}, + Event: &types.HeaderedEvent{Event: outlier.(*gomatrixserverlib.Event)}, }) } // we've got the data by this point so use a background context diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 723dd88f..e024dfae 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -279,13 +279,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( hisVisMap[re.EventID()] = re.Visibility } recEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( - toEvents(snapshot.StreamEventsToEvents(device, recentStreamEvents)), + gomatrixserverlib.ToPDUs(toEvents(snapshot.StreamEventsToEvents(device, recentStreamEvents))), gomatrixserverlib.TopologicalOrderByPrevEvents, ) recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) for i := range recEvents { recentEvents[i] = &rstypes.HeaderedEvent{ - Event: recEvents[i], + Event: recEvents[i].(*gomatrixserverlib.Event), Visibility: hisVisMap[recEvents[i].EventID()], } } @@ -358,13 +358,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( hisVisMap[re.EventID()] = re.Visibility } sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( - toEvents(removeDuplicates(delta.StateEvents, events)), + gomatrixserverlib.ToPDUs(toEvents(removeDuplicates(delta.StateEvents, events))), gomatrixserverlib.TopologicalOrderByAuthEvents, ) delta.StateEvents = make([]*rstypes.HeaderedEvent, len(sEvents)) for i := range sEvents { delta.StateEvents[i] = &rstypes.HeaderedEvent{ - Event: sEvents[i], + Event: sEvents[i].(*gomatrixserverlib.Event), Visibility: hisVisMap[sEvents[i].EventID()], } }