diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index ddda8081..9af0bf59 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,17 +20,22 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) +type checkForAuthAndSoftFailStorage interface { + state.StateResolutionStorage + StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) +} + // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -92,7 +97,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( ctx context.Context, - db storage.Database, + db state.StateResolutionStorage, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 7834e2ed..5bdec0a2 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,6 +19,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "sync" "time" @@ -38,6 +39,19 @@ import ( "github.com/tidwall/gjson" ) +type retryAction int +type commitAction int + +const ( + doNotRetry retryAction = iota + retryLater +) + +const ( + commitTransaction commitAction = iota + rollbackTransaction +) + var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -101,7 +115,8 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { + action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent) + if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } @@ -111,7 +126,12 @@ func (r *Inputer) Start() error { "type": inputRoomEvent.Event.Type(), }).Warn("Roomserver failed to process async event") } - _ = msg.Ack() + switch action { + case retryLater: + _ = msg.Nak() + case doNotRetry: + _ = msg.Ack() + } }) }, // NATS wants to acknowledge automatically by default when the message is @@ -131,6 +151,37 @@ func (r *Inputer) Start() error { return err } +// processRoomEventUsingUpdater opens up a room updater and tries to +// process the event. It returns whether or not we should positively +// or negatively acknowledge the event (i.e. for NATS) and an error +// if it occurred. +func (r *Inputer) processRoomEventUsingUpdater( + ctx context.Context, + roomID string, + inputRoomEvent *api.InputRoomEvent, +) (retryAction, error) { + roomInfo, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) + } + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) + switch action { + case commitTransaction: + if cerr := updater.Commit(); cerr != nil { + return retryLater, fmt.Errorf("updater.Commit: %w", cerr) + } + case rollbackTransaction: + if rerr := updater.Rollback(); rerr != nil { + return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) + } + } + return doNotRetry, err +} + // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -177,7 +228,7 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - err := r.processRoomEvent(ctx, &inputRoomEvent) + _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 16703616..f3fa83d8 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -29,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (err error) { +) (commitAction, error) { select { case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return context.DeadlineExceeded + return rollbackTransaction, context.DeadlineExceeded default: } @@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent( // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -116,11 +118,11 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } default: logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } } } @@ -134,8 +136,8 @@ func (r *Inputer) processRoomEvent( AuthEventIDs: event.AuthEventIDs(), PrevEventIDs: event.PrevEventIDs(), } - if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { - return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) + if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) } } missingAuth := len(missingRes.MissingAuthEventIDs) > 0 @@ -146,8 +148,8 @@ func (r *Inputer) processRoomEvent( RoomID: event.RoomID(), ExcludeSelf: true, } - if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the @@ -176,8 +178,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -193,7 +195,7 @@ func (r *Inputer) processRoomEvent( authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) for _, authEventID := range authEventIDs { if _, ok := knownEvents[authEventID]; !ok { - return fmt.Errorf("missing auth event %s", authEventID) + return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID) } authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } @@ -202,7 +204,8 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + var err error + softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -227,7 +230,7 @@ func (r *Inputer) processRoomEvent( origin: input.Origin, inputer: r, queryer: r.Queryer, - db: r.DB, + db: updater, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), @@ -235,7 +238,7 @@ func (r *Inputer) processRoomEvent( hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) } else { @@ -248,16 +251,16 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -268,23 +271,23 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) - return nil + return commitTransaction, nil } - roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return fmt.Errorf("r.calculateAndSetState: %w", err) + return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) } } @@ -294,13 +297,14 @@ func (r *Inputer) processRoomEvent( "soft_fail": softfail, "missing_prev": missingPrev, }).Warn("Stored rejected event") - return rejectionErr + return commitTransaction, rejectionErr } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context + updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -308,7 +312,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return fmt.Errorf("r.updateLatestEvents: %w", err) + return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -320,7 +324,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -339,14 +343,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) - return nil + return commitTransaction, nil } // fetchAuthEvents will check to see if any of the @@ -358,6 +362,7 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, + updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, @@ -375,7 +380,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -454,9 +459,9 @@ func (r *Inputer) fetchAuthEvents( } // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) } // Now we know about this event, it was stored and the signatures were OK. @@ -471,6 +476,7 @@ func (r *Inputer) fetchAuthEvents( func (r *Inputer) calculateAndSetState( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, @@ -478,14 +484,14 @@ func (r *Inputer) calculateAndSetState( isRejected bool, ) error { var err error - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { + if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -495,13 +501,13 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) + if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { - return fmt.Errorf("r.DB.AddState: %w", err) + if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + return fmt.Errorf("updater.AddState: %w", err) } } else { stateAtEvent.Overwrite = false @@ -512,7 +518,7 @@ func (r *Inputer) calculateAndSetState( } } - err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 6137941e..5173d3ab 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -48,6 +47,7 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, + updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) - } - succeeded := false - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - u := latestEventsUpdater{ ctx: ctx, api: r, @@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } - succeeded = true return } @@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents( type latestEventsUpdater struct { ctx context.Context api *Inputer - updater *shared.LatestEventsUpdater + updater *shared.RoomUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event *gomatrixserverlib.Event @@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the @@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } @@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - return u.api.DB.EventIDs(u.ctx, stateEventNIDs) + return u.updater.EventIDs(u.ctx, stateEventNIDs) } type eventNIDSorter []types.EventNID diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 2511097d..ff3ed7e5 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -31,7 +31,7 @@ import ( // consumers about the invites added or retired by the change in current state. func (r *Inputer) updateMemberships( ctx context.Context, - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { changes := membershipChanges(removed, added) @@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index d401fa0e..4cd2b3de 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/query" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -19,7 +19,7 @@ import ( type missingStateReq struct { origin gomatrixserverlib.ServerName - db storage.Database + db *shared.RoomUpdater inputer *Inputer queryer *query.Queryer keys gomatrixserverlib.JSONVerifier @@ -78,7 +78,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindNew, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -187,7 +187,7 @@ func (t *missingStateReq) processEventWithMissingState( } // TODO: we could do this concurrently? for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil { return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) } } @@ -200,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -217,7 +217,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 15d592b4..e5f69521 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -22,7 +22,6 @@ import ( "sort" "time" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -30,13 +29,25 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type StateResolutionStorage interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) +} + type StateResolution struct { - db storage.Database + db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 15764366..a9851e05 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -86,11 +86,10 @@ type Database interface { // Lookup the event IDs for a batch of event numeric IDs. // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) - // Look up the latest events in a room in preparation for an update. - // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. - // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // Opens and returns a room updater, which locks the room and opens a transaction. + // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) + GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 32e45782..433e445d 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 3a7cf03e..762b3a1f 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt) + rows, err := stmt.QueryContext( ctx, pq.StringArray(eventStateKeys), ) if err != nil { @@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt) + rows, err := stmt.QueryContext(ctx, nIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index e558072a..1d5de582 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 778cd8d7..6c384775 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { - rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByNID lookups a list of state events by event NID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) if err != nil { return nil, err } @@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { - rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { - rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { - rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 344302c8..176c16e4 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) @@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired( // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index b0d906c8..48c2c35c 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { } func (s *membershipStatements) InsertMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) @@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership( } func (s *membershipStatements) SelectMembershipForUpdate( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, @@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, localOnly bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { @@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { stmt = s.selectMembershipsFromRoomStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID) if err != nil { return @@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows @@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } func (s *membershipStatements) UpdateMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( @@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms( + ctx context.Context, txn *sql.Tx, + roomNIDs []types.RoomNID, +) (map[types.EventStateKeyNID]int, error) { roomIDarray := make([]int64, len(roomNIDs)) for i := range roomNIDs { roomIDarray[i] = int64(roomNIDs[i]) } - rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) if err != nil { return nil, err } @@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers( + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, searchString string, limit int, +) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - forget bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( ctx, roomNID, targetUserNID, forget, @@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, +) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, serverName gomatrixserverlib.ServerName, +) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 8deb6844..15985fcd 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index 031825fe..d13df8e7 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) ([]string, error) { - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return nil, err } @@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index f51eba4d..b2685084 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDs pq.Int64Array - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, ) if err == sql.ErrNoRows { @@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs( ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 - stmt := s.selectLatestEventNIDsStmt + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { return nil, 0, err @@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { - rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) if err != nil { return nil, err } @@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { var array pq.Int64Array for _, nid := range roomNIDs { array = append(array, int64(nid)) } - rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } @@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { var array pq.StringArray for _, roomID := range roomIDs { array = append(array, roomID) } - rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 27d85e83..6f8f9e1b 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData( for _, e := range entries { nids = append(nids, e.EventNID) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), eventNIDsAsArray(nids), ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt) + rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 4fc0fa48..ce9f2463 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) for i := range stateNIDs { nids[i] = int64(stateNIDs[i]) } - rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go deleted file mode 100644 index 36865081..00000000 --- a/roomserver/storage/shared/latest_events_updater.go +++ /dev/null @@ -1,133 +0,0 @@ -package shared - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type LatestEventsUpdater struct { - transaction - d *Database - roomInfo types.RoomInfo - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -func rollback(txn *sql.Tx) { - if txn == nil { - return - } - txn.Rollback() // nolint: errcheck -} - -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) - if err != nil { - rollback(txn) - return nil, err - } - stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - rollback(txn) - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - rollback(txn) - return nil, err - } - } - return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - return u.roomInfo.RoomVersion -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer -func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { - return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) - } - if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { - roomInfo.StateSnapshotNID = currentStateSnapshotNID - roomInfo.IsStub = false - u.d.Cache.StoreRoomInfo(roomID, roomInfo) - } - } - return nil - }) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) - }) -} - -func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) -} diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go new file mode 100644 index 00000000..bb9f5dc6 --- /dev/null +++ b/roomserver/storage/shared/room_updater.go @@ -0,0 +1,262 @@ +package shared + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type RoomUpdater struct { + transaction + d *Database + roomInfo *types.RoomInfo + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +func rollback(txn *sql.Tx) { + if txn == nil { + return + } + txn.Rollback() // nolint: errcheck +} + +func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) { + // If the roomInfo is nil then that means that the room doesn't exist + // yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that + // would involve locking a row on the table that doesn't exist. Instead + // we will just run with a normal database transaction. It'll either + // succeed, processing a create event which creates the room, or it won't. + if roomInfo == nil { + return &RoomUpdater{ + transaction{ctx, txn}, d, nil, nil, "", 0, + }, nil + } + + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) + if err != nil { + rollback(txn) + return nil, err + } + stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + rollback(txn) + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + rollback(txn) + return nil, err + } + } + return &RoomUpdater{ + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Commit() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Commit() +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Rollback() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Rollback() +} + +// RoomVersion implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { + return u.roomInfo.RoomVersion +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer +func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } + } + return nil + }) +} + +func (u *RoomUpdater) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + return u.d.events(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) +} + +func (u *RoomUpdater) StoreEvent( + ctx context.Context, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) +} + +func (u *RoomUpdater) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs) +} + +func (u *RoomUpdater) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return u.d.stateEntries(ctx, u.txn, stateBlockNIDs) +} + +func (u *RoomUpdater) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples) +} + +func (u *RoomUpdater) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state) +} + +func (u *RoomUpdater) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID) + }) +} + +func (u *RoomUpdater) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return u.d.eventTypeNIDs(ctx, u.txn, eventTypes) +} + +func (u *RoomUpdater) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys) +} + +func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return u.d.roomInfo(ctx, u.txn, roomID) +} + +func (u *RoomUpdater) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { + if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { + roomInfo.StateSnapshotNID = currentStateSnapshotNID + roomInfo.IsStub = false + u.d.Cache.StoreRoomInfo(roomID, roomInfo) + } + } + return nil + }) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) +} + +func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d4c5ebb5..127cd1f5 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -26,23 +26,23 @@ import ( const redactionsArePermanent = true type Database struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - RedactionsTable tables.Redactions - GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + RedactionsTable tables.Redactions + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { @@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.eventTypeNIDs(ctx, nil, eventTypes) +} + +func (d *Database) eventTypeNIDs( + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) remaining := []string{} @@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs( } } if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining) + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining) if err != nil { return nil, err } @@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs( func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) } func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) +} + +func (d *Database) eventStateKeyNIDs( + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) remaining := []string{} @@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs( } } if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining) + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining) if err != nil { return nil, err } @@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs( func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { - return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs) } func (d *Database) StateEntriesForTuples( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples) +} + +func (d *Database) stateEntriesForTuples( + ctx context.Context, txn *sql.Tx, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := []types.StateEntryList{} for i, entry := range entries { - entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples( } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return d.roomInfo(ctx, nil, roomID) +} + +func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { return &roomInfo, nil } - roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) + roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) if err == nil && roomInfo != nil { d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomInfo(roomID, *roomInfo) @@ -153,13 +177,22 @@ func (d *Database) AddState( roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return d.addState(ctx, nil, roomNID, stateBlockNIDs, state) +} + +func (d *Database) addState( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { if len(stateBlockNIDs) > 0 && len(state) > 0 { // Check to see if the event already appears in any of the existing state // blocks. If it does then we should not add it again, as this will just // result in excess state blocks and snapshots. // TODO: Investigate why this is happening - probably input_events.go! - blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) if berr != nil { return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) } @@ -180,7 +213,7 @@ func (d *Database) AddState( } } } - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { if len(state) > 0 { // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID @@ -205,7 +238,13 @@ func (d *Database) AddState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs) +} + +func (d *Database) eventNIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) } func (d *Database) SetState( @@ -219,24 +258,34 @@ func (d *Database) SetState( func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { - return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { - _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return d.snapshotNIDFromEventID(ctx, nil, eventID) +} + +func (d *Database) snapshotNIDFromEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) return stateNID, err } func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { - return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs) +} + +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.eventNIDs(ctx, txn, eventIDs) if err != nil { return nil, err } @@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type nids = append(nids, nid) } - return d.Events(ctx, nids) + return d.events(ctx, txn, nids) } func (d *Database) LatestEventIDs( @@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs( func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) + return d.stateBlockNIDs(ctx, nil, stateNIDs) +} + +func (d *Database) stateBlockNIDs( + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs) } func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.stateEntries(ctx, nil, stateBlockNIDs) +} + +func (d *Database) stateEntries( + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := make([]types.StateEntryList, 0, len(entries)) for i, entry := range entries { - eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias) } func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID) } func (d *Database) GetCreatorIDForAlias( ctx context.Context, alias string, ) (string, error) { - return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias) } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { @@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req senderMembershipEventNID, senderMembership, isRoomforgotten, err := d.MembershipTable.SelectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, + ctx, nil, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room @@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly) +} + +func (d *Database) getMembershipEventNIDsForRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.MembershipTable.SelectMembershipsFromRoomAndMembership( - ctx, roomNID, tables.MembershipStateJoin, localOnly, + ctx, txn, roomNID, tables.MembershipStateJoin, localOnly, ) } - return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) + return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly) } func (d *Database) GetInvitesForUser( @@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { - return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + return d.events(ctx, nil, eventNIDs) +} + +func (d *Database) events( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } var roomNIDs map[types.EventNID]types.RoomNID - roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) + roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) if err != nil { return nil, err } @@ -398,7 +471,7 @@ func (d *Database) Events( } fetchNIDList = append(fetchNIDList, n) } - dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) + dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList) if err != nil { return nil, err } @@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater( return updater, err } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*LatestEventsUpdater, error) { - if d.GetLatestEventsForUpdateFn != nil { - return d.GetLatestEventsForUpdateFn(ctx, roomInfo) +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*RoomUpdater, error) { + if d.GetRoomUpdaterFn != nil { + return d.GetRoomUpdaterFn(ctx, roomInfo) } txn, err := d.DB.Begin() if err != nil { return nil, err } - var updater *LatestEventsUpdater + var updater *RoomUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) + updater, err = NewRoomUpdater(ctx, d, txn, roomInfo) return err }) return updater, err @@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate( func (d *Database) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected) +} + +func (d *Database) storeEvent( + ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID @@ -472,8 +552,11 @@ func (d *Database) StoreEvent( redactedEventID string err error ) - - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var txn *sql.Tx + if updater != nil { + txn = updater.txn + } + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { // TODO: Here we should aim to have two different code paths for new rooms // vs existing ones. @@ -546,42 +629,32 @@ func (d *Database) StoreEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - var roomInfo *types.RoomInfo - var updater *LatestEventsUpdater if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // to do writes however then this will need to go inside `Writer.Do`. - updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) - } - // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents - // and EndTransaction in a writer then it's possible for a new write txn to be made between the two - // function calls which will then fail with 'database is locked'. This new write txn would HAVE to be - // something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to - // SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases - // as they don't go via InputRoomEvents - err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error { - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return fmt.Errorf("updater.StorePreviousEvents: %w", err) + succeeded := false + if updater == nil { + var roomInfo *types.RoomInfo + roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } - succeeded := true - err = sqlutil.EndTransaction(updater, &succeeded) - return err - }) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", err + if roomInfo == nil && len(prevEvents) > 0 { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } + updater, err = d.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) } + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + } + succeeded = true } return eventNID, roomNID, types.StateAtEvent{ @@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, true) + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } func (d *Database) assignRoomNID( @@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s eventNIDs = append(eventNIDs, e.EventNID) } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } // return the event requested for _, e := range entries { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { - data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID}) if err != nil { return nil, err } @@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership } return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) } - roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) } - roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) } @@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // isn't a failure. - eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) + eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) } @@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) + eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } - events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) } @@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) if err != nil { return nil, err } @@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) if err != nil { return nil, err } @@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { - return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) + return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) } // GetServerInRoom returns true if we think a server is in a given room or false otherwise. func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { - return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) + return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName) } // GetKnownUsers searches all users that userID knows about. @@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin if err != nil { return nil, err } - return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) + return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.RoomsTable.SelectRoomIDs(ctx) + return d.RoomsTable.SelectRoomIDs(ctx, nil) } // ForgetRoom sets a users room to forgotten func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID}) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 53b21929..f470ea32 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 62fbce2d..bf12d5b8 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + } if err != nil { return nil, err } @@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 22df3fb2..f2c9c42f 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( if err != nil { return nil, err } + stmt := sqlutil.TxStmt(txn, selectPrep) /////////////// - rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + rows, err := stmt.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 7483e281..e1e6a597 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) @@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) @@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( if err != nil { return nil, fmt.Errorf("s.db.Prepare: %w", err) } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, params...) if err != nil { return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) @@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( if err != nil { return nil, err } + selectPrep = sqlutil.TxStmt(txn, selectPrep) ////////////// rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) @@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { @@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) @@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { @@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iEventNIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { iEventNIDs[i] = v diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index c1d7347a..d54d313a 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) @@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired( // selectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 2e58431d..181b4b4c 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt @@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { selectStmt = s.selectMembershipsFromRoomStmt } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + } if err != nil { return nil, err } @@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( @@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index b07c0ac4..9e416ace 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 323945b8..7c7bead9 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return } @@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index c441daec..5413475e 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { return roomIDs, nil } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, ) if err != nil { @@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v @@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } if err != nil { return nil, err } @@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i, v := range roomIDs { iRoomIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 58b0b5dc..d51fc492 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData( if err != nil { return 0, fmt.Errorf("json.Marshal: %w", err) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), js, ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { intfs := make([]interface{}, len(stateBlockNIDs)) for i := range stateBlockNIDs { @@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 040d99ae..3c4bde3f 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { @@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 1fcc7989..325c253b 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, + DB: db, + Cache: cache, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + RedactionsTable: redactions, + GetRoomUpdaterFn: d.GetRoomUpdater, } return nil } @@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { return false } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*shared.LatestEventsUpdater, error) { +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*shared.RoomUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional // write transactions independent of this one which will consistently cause // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) + return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 6ad7ed2e..fed39b94 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -18,20 +18,20 @@ type EventJSONPair struct { type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error - BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) + BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error) } type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) - BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error) } type EventStateKeys interface { InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) - BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) } type Events interface { @@ -42,12 +42,12 @@ type Events interface { SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError - BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) + BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. - BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error) UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error @@ -55,12 +55,12 @@ type Events interface { BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) // BulkSelectEventID returns a map from numeric event ID to string event ID. - BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) - SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) + SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) } type Rooms interface { @@ -69,29 +69,29 @@ type Rooms interface { SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error - SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) - SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) - SelectRoomIDs(ctx context.Context) ([]string, error) - BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) - BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) + SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) + SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) } type StateSnapshot interface { InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) - BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) - SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) - SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) - SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) } @@ -106,7 +106,7 @@ type Invites interface { InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. - SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) + SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) } type MembershipState int64 @@ -121,24 +121,24 @@ const ( type Membership interface { InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) - SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) - SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) - SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) + SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error - SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) - SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error - SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) - SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) + SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) } type Published interface { UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) - SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) + SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) } type RedactionInfo struct {