Full roomserver input transactional isolation (#2141)

* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input

* Better transaction management

* Tweak order

* Handle cases where the room does not exist

* Other fixes

* More tweaks

* Fill some gaps

* Fill in the gaps

* good lord it gets worse

* Don't roll back transactions when events rejected

* Pass through errors properly

* Fix bugs

* Fix incorrect error check

* Don't panic on nil txns

* Tweaks

* Hopefully fix panics for good in SQLite this time

* Fix rollback

* Minor bug fixes with latest event updater

* Some review comments

* Revert "Some review comments"

This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e.

* Fix a couple of bugs

* Clearer commit and rollback results

* Remove unnecessary prepares
This commit is contained in:
Neil Alexander 2022-02-04 10:39:34 +00:00 committed by GitHub
parent 4d9f5b2e57
commit eb352a5f6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 867 additions and 499 deletions

View File

@ -20,17 +20,22 @@ import (
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type checkForAuthAndSoftFailStorage interface {
state.StateResolutionStorage
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
}
// CheckForSoftFail returns true if the event should be soft-failed // CheckForSoftFail returns true if the event should be soft-failed
// and false otherwise. The return error value should be checked before // and false otherwise. The return error value should be checked before
// the soft-fail bool. // the soft-fail bool.
func CheckForSoftFail( func CheckForSoftFail(
ctx context.Context, ctx context.Context,
db storage.Database, db checkForAuthAndSoftFailStorage,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
stateEventIDs []string, stateEventIDs []string,
) (bool, error) { ) (bool, error) {
@ -92,7 +97,7 @@ func CheckForSoftFail(
// Returns the numeric IDs for the auth events. // Returns the numeric IDs for the auth events.
func CheckAuthEvents( func CheckAuthEvents(
ctx context.Context, ctx context.Context,
db storage.Database, db checkForAuthAndSoftFailStorage,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string, authEventIDs []string,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
// loadAuthEvents loads the events needed for authentication from the supplied room state. // loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents( func loadAuthEvents(
ctx context.Context, ctx context.Context,
db storage.Database, db state.StateResolutionStorage,
needed gomatrixserverlib.StateNeeded, needed gomatrixserverlib.StateNeeded,
state []types.StateEntry, state []types.StateEntry,
) (result authEvents, err error) { ) (result authEvents, err error) {

View File

@ -19,6 +19,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
@ -38,6 +39,19 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
type retryAction int
type commitAction int
const (
doNotRetry retryAction = iota
retryLater
)
const (
commitTransaction commitAction = iota
rollbackTransaction
)
var keyContentFields = map[string]string{ var keyContentFields = map[string]string{
"m.room.join_rules": "join_rule", "m.room.join_rules": "join_rule",
"m.room.history_visibility": "history_visibility", "m.room.history_visibility": "history_visibility",
@ -101,7 +115,8 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer _ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)
} }
@ -111,7 +126,12 @@ func (r *Inputer) Start() error {
"type": inputRoomEvent.Event.Type(), "type": inputRoomEvent.Event.Type(),
}).Warn("Roomserver failed to process async event") }).Warn("Roomserver failed to process async event")
} }
_ = msg.Ack() switch action {
case retryLater:
_ = msg.Nak()
case doNotRetry:
_ = msg.Ack()
}
}) })
}, },
// NATS wants to acknowledge automatically by default when the message is // NATS wants to acknowledge automatically by default when the message is
@ -131,6 +151,37 @@ func (r *Inputer) Start() error {
return err return err
} }
// processRoomEventUsingUpdater opens up a room updater and tries to
// process the event. It returns whether or not we should positively
// or negatively acknowledge the event (i.e. for NATS) and an error
// if it occurred.
func (r *Inputer) processRoomEventUsingUpdater(
ctx context.Context,
roomID string,
inputRoomEvent *api.InputRoomEvent,
) (retryAction, error) {
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err)
}
updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
action, err := r.processRoomEvent(ctx, updater, inputRoomEvent)
switch action {
case commitTransaction:
if cerr := updater.Commit(); cerr != nil {
return retryLater, fmt.Errorf("updater.Commit: %w", cerr)
}
case rollbackTransaction:
if rerr := updater.Rollback(); rerr != nil {
return retryLater, fmt.Errorf("updater.Rollback: %w", rerr)
}
}
return doNotRetry, err
}
// InputRoomEvents implements api.RoomserverInternalAPI // InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents( func (r *Inputer) InputRoomEvents(
ctx context.Context, ctx context.Context,
@ -177,7 +228,7 @@ func (r *Inputer) InputRoomEvents(
worker.Act(nil, func() { worker.Act(nil, func() {
defer eventsInProgress.Delete(index) defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
err := r.processRoomEvent(ctx, &inputRoomEvent) _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
if err != nil { if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err) sentry.CaptureException(err)

View File

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo // nolint:gocyclo
func (r *Inputer) processRoomEvent( func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) (err error) { ) (commitAction, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// Before we do anything, make sure the context hasn't expired for this pending task. // Before we do anything, make sure the context hasn't expired for this pending task.
// If it has then we'll give up straight away — it's probably a synchronous input // If it has then we'll give up straight away — it's probably a synchronous input
// request and the caller has already given up, but the inbox task was still queued. // request and the caller has already given up, but the inbox task was still queued.
return context.DeadlineExceeded return rollbackTransaction, context.DeadlineExceeded
default: default:
} }
@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent(
// if we have already got this event then do not process it again, if the input kind is an outlier. // if we have already got this event then do not process it again, if the input kind is an outlier.
// Outliers contain no extra information which may warrant a re-processing. // Outliers contain no extra information which may warrant a re-processing.
if input.Kind == api.KindOutlier { if input.Kind == api.KindOutlier {
evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
if err2 == nil && len(evs) == 1 { if err2 == nil && len(evs) == 1 {
// check hash matches if we're on early room versions where the event ID was a random string // check hash matches if we're on early room versions where the event ID was a random string
idFormat, err2 := headered.RoomVersion.EventIDFormat() idFormat, err2 := headered.RoomVersion.EventIDFormat()
@ -116,11 +118,11 @@ func (r *Inputer) processRoomEvent(
case gomatrixserverlib.EventIDFormatV1: case gomatrixserverlib.EventIDFormatV1:
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
logger.Debugf("Already processed event; ignoring") logger.Debugf("Already processed event; ignoring")
return nil return rollbackTransaction, nil
} }
default: default:
logger.Debugf("Already processed event; ignoring") logger.Debugf("Already processed event; ignoring")
return nil return rollbackTransaction, nil
} }
} }
} }
@ -134,8 +136,8 @@ func (r *Inputer) processRoomEvent(
AuthEventIDs: event.AuthEventIDs(), AuthEventIDs: event.AuthEventIDs(),
PrevEventIDs: event.PrevEventIDs(), PrevEventIDs: event.PrevEventIDs(),
} }
if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
} }
} }
missingAuth := len(missingRes.MissingAuthEventIDs) > 0 missingAuth := len(missingRes.MissingAuthEventIDs) > 0
@ -146,8 +148,8 @@ func (r *Inputer) processRoomEvent(
RoomID: event.RoomID(), RoomID: event.RoomID(),
ExcludeSelf: true, ExcludeSelf: true,
} }
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
} }
// Sort all of the servers into a map so that we can randomise // Sort all of the servers into a map so that we can randomise
// their order. Then make sure that the input origin and the // their order. Then make sure that the input origin and the
@ -176,8 +178,8 @@ func (r *Inputer) processRoomEvent(
isRejected := false isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil) authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{} knownEvents := map[string]*types.Event{}
if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
return fmt.Errorf("r.fetchAuthEvents: %w", err) return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err)
} }
// Check if the event is allowed by its auth events. If it isn't then // Check if the event is allowed by its auth events. If it isn't then
@ -193,7 +195,7 @@ func (r *Inputer) processRoomEvent(
authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) authEventNIDs := make([]types.EventNID, 0, len(authEventIDs))
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
if _, ok := knownEvents[authEventID]; !ok { if _, ok := knownEvents[authEventID]; !ok {
return fmt.Errorf("missing auth event %s", authEventID) return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID)
} }
authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID)
} }
@ -202,7 +204,8 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindNew { if input.Kind == api.KindNew {
// Check that the event passes authentication checks based on the // Check that the event passes authentication checks based on the
// current room state. // current room state.
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) var err error
softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs)
if err != nil { if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event") logger.WithError(err).Warn("Error authing soft-failed event")
} }
@ -227,7 +230,7 @@ func (r *Inputer) processRoomEvent(
origin: input.Origin, origin: input.Origin,
inputer: r, inputer: r,
queryer: r.Queryer, queryer: r.Queryer,
db: r.DB, db: updater,
federation: r.FSAPI, federation: r.FSAPI,
keys: r.KeyRing, keys: r.KeyRing,
roomsMu: internal.NewMutexByRoom(), roomsMu: internal.NewMutexByRoom(),
@ -235,7 +238,7 @@ func (r *Inputer) processRoomEvent(
hadEvents: map[string]bool{}, hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
} }
if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
isRejected = true isRejected = true
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
} else { } else {
@ -248,16 +251,16 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
if !isRejected && redactedEventID == event.EventID() { if !isRejected && redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, event) r, rerr := eventutil.RedactEvent(redactionEvent, event)
if rerr != nil { if rerr != nil {
return fmt.Errorf("eventutil.RedactEvent: %w", rerr) return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr)
} }
event = r event = r
} }
@ -268,23 +271,23 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindOutlier { if input.Kind == api.KindOutlier {
logger.Debug("Stored outlier") logger.Debug("Stored outlier")
hooks.Run(hooks.KindNewEventPersisted, headered) hooks.Run(hooks.KindNewEventPersisted, headered)
return nil return commitTransaction, nil
} }
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err) return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err)
} }
if roomInfo == nil { if roomInfo == nil {
return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
} }
if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet. // We haven't calculated a state for this event yet.
// Lets calculate one. // Lets calculate one.
err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.calculateAndSetState: %w", err) return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err)
} }
} }
@ -294,13 +297,14 @@ func (r *Inputer) processRoomEvent(
"soft_fail": softfail, "soft_fail": softfail,
"missing_prev": missingPrev, "missing_prev": missingPrev,
}).Warn("Stored rejected event") }).Warn("Stored rejected event")
return rejectionErr return commitTransaction, rejectionErr
} }
switch input.Kind { switch input.Kind {
case api.KindNew: case api.KindNew:
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
updater, // room updater
roomInfo, // room info for the room being updated roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below) stateAtEvent, // state at event (below)
event, // event event, // event
@ -308,7 +312,7 @@ func (r *Inputer) processRoomEvent(
input.TransactionID, // transaction ID input.TransactionID, // transaction ID
input.HasState, // rewrites state? input.HasState, // rewrites state?
); err != nil { ); err != nil {
return fmt.Errorf("r.updateLatestEvents: %w", err) return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err)
} }
case api.KindOld: case api.KindOld:
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
@ -320,7 +324,7 @@ func (r *Inputer) processRoomEvent(
}, },
}) })
if err != nil { if err != nil {
return fmt.Errorf("r.WriteOutputEvents (old): %w", err) return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err)
} }
} }
@ -339,14 +343,14 @@ func (r *Inputer) processRoomEvent(
}, },
}) })
if err != nil { if err != nil {
return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
} }
} }
// Everything was OK — the latest events updater didn't error and // Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call. // we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered) hooks.Run(hooks.KindNewEventPersisted, headered)
return nil return commitTransaction, nil
} }
// fetchAuthEvents will check to see if any of the // fetchAuthEvents will check to see if any of the
@ -358,6 +362,7 @@ func (r *Inputer) processRoomEvent(
// they are now in the database. // they are now in the database.
func (r *Inputer) fetchAuthEvents( func (r *Inputer) fetchAuthEvents(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
logger *logrus.Entry, logger *logrus.Entry,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents, auth *gomatrixserverlib.AuthEvents,
@ -375,7 +380,7 @@ func (r *Inputer) fetchAuthEvents(
} }
for _, authEventID := range authEventIDs { for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{} unknown[authEventID] = struct{}{}
continue continue
@ -454,9 +459,9 @@ func (r *Inputer) fetchAuthEvents(
} }
// Finally, store the event in the database. // Finally, store the event in the database.
eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return fmt.Errorf("updater.StoreEvent: %w", err)
} }
// Now we know about this event, it was stored and the signatures were OK. // Now we know about this event, it was stored and the signatures were OK.
@ -471,6 +476,7 @@ func (r *Inputer) fetchAuthEvents(
func (r *Inputer) calculateAndSetState( func (r *Inputer) calculateAndSetState(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
input *api.InputRoomEvent, input *api.InputRoomEvent,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
@ -478,14 +484,14 @@ func (r *Inputer) calculateAndSetState(
isRejected bool, isRejected bool,
) error { ) error {
var err error var err error
roomState := state.NewStateResolution(r.DB, roomInfo) roomState := state.NewStateResolution(updater, roomInfo)
if input.HasState { if input.HasState {
// Check here if we think we're in the room already. // Check here if we think we're in the room already.
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID var joinEventNIDs []types.EventNID
// Request join memberships only for local users only. // Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about // If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case // the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it. // we should overwrite it rather than merge it.
@ -495,13 +501,13 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
} }
entries = types.DeduplicateStateEntries(entries) entries = types.DeduplicateStateEntries(entries)
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return fmt.Errorf("r.DB.AddState: %w", err) return fmt.Errorf("updater.AddState: %w", err)
} }
} else { } else {
stateAtEvent.Overwrite = false stateAtEvent.Overwrite = false
@ -512,7 +518,7 @@ func (r *Inputer) calculateAndSetState(
} }
} }
err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
if err != nil { if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err) return fmt.Errorf("r.DB.SetState: %w", err)
} }

View File

@ -20,7 +20,6 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -48,6 +47,7 @@ import (
// Can only be called once at a time // Can only be called once at a time
func (r *Inputer) updateLatestEvents( func (r *Inputer) updateLatestEvents(
ctx context.Context, ctx context.Context,
updater *shared.RoomUpdater,
roomInfo *types.RoomInfo, roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent, stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event, event *gomatrixserverlib.Event,
@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID, transactionID *api.TransactionID,
rewritesState bool, rewritesState bool,
) (err error) { ) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx, ctx: ctx,
api: r, api: r,
@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents(
return fmt.Errorf("u.doUpdateLatestEvents: %w", err) return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
} }
succeeded = true
return return
} }
@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents(
type latestEventsUpdater struct { type latestEventsUpdater struct {
ctx context.Context ctx context.Context
api *Inputer api *Inputer
updater *shared.LatestEventsUpdater updater *shared.RoomUpdater
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
event *gomatrixserverlib.Event event *gomatrixserverlib.Event
@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) latestState() error {
var err error var err error
roomState := state.NewStateResolution(u.api.DB, u.roomInfo) roomState := state.NewStateResolution(u.updater, u.roomInfo)
// Work out if the state at the extremities has actually changed // Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the // or not. If they haven't then we won't bother doing all of the
@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 { if len(extraEventIDs) == 0 {
return nil, nil return nil, nil
} }
extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
return u.api.DB.EventIDs(u.ctx, stateEventNIDs) return u.updater.EventIDs(u.ctx, stateEventNIDs)
} }
type eventNIDSorter []types.EventNID type eventNIDSorter []types.EventNID

View File

@ -31,7 +31,7 @@ import (
// consumers about the invites added or retired by the change in current state. // consumers about the invites added or retired by the change in current state.
func (r *Inputer) updateMemberships( func (r *Inputer) updateMemberships(
ctx context.Context, ctx context.Context,
updater *shared.LatestEventsUpdater, updater *shared.RoomUpdater,
removed, added []types.StateEntry, removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added) changes := membershipChanges(removed, added)
@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships(
} }
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
updater *shared.LatestEventsUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event, remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent, updates []api.OutputEvent,

View File

@ -11,7 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -19,7 +19,7 @@ import (
type missingStateReq struct { type missingStateReq struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
db storage.Database db *shared.RoomUpdater
inputer *Inputer inputer *Inputer
queryer *query.Queryer queryer *query.Queryer
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
@ -78,7 +78,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG // in the gap in the DAG
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindNew, Kind: api.KindNew,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -187,7 +187,7 @@ func (t *missingStateReq) processEventWithMissingState(
} }
// TODO: we could do this concurrently? // TODO: we could do this concurrently?
for _, ire := range outlierRoomEvents { for _, ire := range outlierRoomEvents {
if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil {
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
} }
} }
@ -200,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID()) stateIDs = append(stateIDs, event.EventID())
} }
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion), Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,
@ -217,7 +217,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the // they will automatically fast-forward based on the room state at the
// extremity in the last step. // extremity in the last step.
for _, newEvent := range newEvents { for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld, Kind: api.KindOld,
Event: newEvent.Headered(roomVersion), Event: newEvent.Headered(roomVersion),
Origin: t.origin, Origin: t.origin,

View File

@ -22,7 +22,6 @@ import (
"sort" "sort"
"time" "time"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -30,13 +29,25 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
type StateResolutionStorage interface {
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
}
type StateResolution struct { type StateResolution struct {
db storage.Database db StateResolutionStorage
roomInfo *types.RoomInfo roomInfo *types.RoomInfo
events map[types.EventNID]*gomatrixserverlib.Event events map[types.EventNID]*gomatrixserverlib.Event
} }
func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
return StateResolution{ return StateResolution{
db: db, db: db,
roomInfo: roomInfo, roomInfo: roomInfo,

View File

@ -86,11 +86,10 @@ type Database interface {
// Lookup the event IDs for a batch of event numeric IDs. // Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Look up the latest events in a room in preparation for an update. // Opens and returns a room updater, which locks the room and opens a transaction.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
// If this returns an error then no further action is required. // If this returns an error then no further action is required.
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
// Look up event references for the latest events in the room and the current state snapshot. // Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.

View File

@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
} }
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
rows, err := stmt.QueryContext(
ctx, pq.StringArray(eventStateKeys), ctx, pq.StringArray(eventStateKeys),
) )
if err != nil { if err != nil {
@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
} }
func (s *eventStateKeyStatements) BulkSelectEventStateKey( func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs { for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i]) nIDs[i] = int64(eventStateKeyNIDs[i])
} }
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
rows, err := stmt.QueryContext(ctx, nIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
} }
func (s *eventTypeStatements) BulkSelectEventTypeNID( func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string, ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByNID lookups a list of state events by event NID. // bulkSelectStateEventByNID lookups a list of state events by event NID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID( func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) { ) (map[types.EventNID]types.RoomNID, error) {
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) { ) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {

View File

@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
} }
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
} }
func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) { ) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
} }
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
func (s *membershipStatements) SelectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if localOnly { if localOnly {
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else { } else {
stmt = s.selectMembershipsFromRoomStmt stmt = s.selectMembershipsFromRoomStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID) rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return return
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} }
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows var rows *sql.Rows
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else { } else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership) rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} }
func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool, eventNID types.EventNID, forgotten bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs)) roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs { for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i]) roomIDarray[i] = int64(roomNIDs[i])
} }
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, searchString string, limit int,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget, ctx, roomNID, targetUserNID, forget,
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err return err
} }
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (s *membershipStatements) SelectLocalServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil return found, nil
} }
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (s *membershipStatements) SelectServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil

View File

@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
} }
func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
} }
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View File

@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo var info types.RoomInfo
var latestNIDs pq.Int64Array var latestNIDs pq.Int64Array
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
) ([]types.EventNID, types.StateSnapshotNID, error) { ) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
var stateSnapshotNID int64 var stateSnapshotNID int64
stmt := s.selectLatestEventNIDsStmt stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
var array pq.Int64Array var array pq.Int64Array
for _, nid := range roomNIDs { for _, nid := range roomNIDs {
array = append(array, int64(nid)) array = append(array, int64(nid))
} }
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray var array pq.StringArray
for _, roomID := range roomIDs { for _, roomID := range roomIDs {
array = append(array, roomID) array = append(array, roomID)
} }
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx,
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
for _, e := range entries { for _, e := range entries {
nids = append(nids, e.EventNID) nids = append(nids, e.EventNID)
} }
err = s.insertStateDataStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), eventNIDsAsArray(nids), ctx, nids.Hash(), eventNIDsAsArray(nids),
).Scan(&id) ).Scan(&id)
return return
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) { ) ([][]types.EventNID, error) {
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs)) nids := make([]int64, len(stateNIDs))
for i := range stateNIDs { for i := range stateNIDs {
nids[i] = int64(stateNIDs[i]) nids[i] = int64(stateNIDs[i])
} }
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,133 +0,0 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type LatestEventsUpdater struct {
transaction
d *Database
roomInfo types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &LatestEventsUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

View File

@ -0,0 +1,262 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type RoomUpdater struct {
transaction
d *Database
roomInfo *types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) {
// If the roomInfo is nil then that means that the room doesn't exist
// yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that
// would involve locking a row on the table that doesn't exist. Instead
// we will just run with a normal database transaction. It'll either
// succeed, processing a create event which creates the room, or it won't.
if roomInfo == nil {
return &RoomUpdater{
transaction{ctx, txn}, d, nil, nil, "", 0,
}, nil
}
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &RoomUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// Implements sqlutil.Transaction
func (u *RoomUpdater) Commit() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Commit()
}
// Implements sqlutil.Transaction
func (u *RoomUpdater) Rollback() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Rollback()
}
// RoomVersion implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
})
}
func (u *RoomUpdater) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
return u.d.events(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
}
func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
}
func (u *RoomUpdater) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs)
}
func (u *RoomUpdater) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return u.d.stateEntries(ctx, u.txn, stateBlockNIDs)
}
func (u *RoomUpdater) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples)
}
func (u *RoomUpdater) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
}
func (u *RoomUpdater) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
})
}
func (u *RoomUpdater) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return u.d.eventTypeNIDs(ctx, u.txn, eventTypes)
}
func (u *RoomUpdater) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys)
}
func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return u.d.roomInfo(ctx, u.txn, roomID)
}
func (u *RoomUpdater) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

View File

@ -26,23 +26,23 @@ import (
const redactionsArePermanent = true const redactionsArePermanent = true
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
Writer sqlutil.Writer Writer sqlutil.Writer
EventsTable tables.Events EventsTable tables.Events
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites InvitesTable tables.Invites
MembershipTable tables.Membership MembershipTable tables.Membership
PublishedTable tables.Published PublishedTable tables.Published
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
func (d *Database) EventTypeNIDs( func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string, ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes)
}
func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID) result := make(map[string]types.EventTypeNID)
remaining := []string{} remaining := []string{}
@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs(
} }
} }
if len(remaining) > 0 { if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining) nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs(
func (d *Database) EventStateKeys( func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
} }
func (d *Database) EventStateKeyNIDs( func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string, ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
}
func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID) result := make(map[string]types.EventStateKeyNID)
remaining := []string{} remaining := []string{}
@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs(
} }
} }
if len(remaining) > 0 { if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining) nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs( func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
} }
func (d *Database) StateEntriesForTuples( func (d *Database) StateEntriesForTuples(
ctx context.Context, ctx context.Context,
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples)
}
func (d *Database) stateEntriesForTuples(
ctx context.Context, txn *sql.Tx,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs, ctx, txn, stateBlockNIDs,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
} }
lists := []types.StateEntryList{} lists := []types.StateEntryList{}
for i, entry := range entries { for i, entry := range entries {
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples)
if err != nil { if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples(
} }
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil return &roomInfo, nil
} }
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil { if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo) d.Cache.StoreRoomInfo(roomID, *roomInfo)
@ -153,13 +177,22 @@ func (d *Database) AddState(
roomNID types.RoomNID, roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID, stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry, state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
}
func (d *Database) addState(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 { if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state // Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just // blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots. // result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go! // TODO: Investigate why this is happening - probably input_events.go!
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil { if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
} }
@ -180,7 +213,7 @@ func (d *Database) AddState(
} }
} }
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 { if len(state) > 0 {
// If there's any state left to add then let's add new blocks. // If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID var stateBlockNID types.StateBlockNID
@ -205,7 +238,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs( func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) { ) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) return d.eventNIDs(ctx, nil, eventIDs)
}
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
} }
func (d *Database) SetState( func (d *Database) SetState(
@ -219,24 +258,34 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs( func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
} }
func (d *Database) SnapshotNIDFromEventID( func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) return d.snapshotNIDFromEventID(ctx, nil, eventID)
}
func (d *Database) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
return stateNID, err return stateNID, err
} }
func (d *Database) EventIDs( func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
} }
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs) return d.eventsFromIDs(ctx, nil, eventIDs)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid) nids = append(nids, nid)
} }
return d.Events(ctx, nids) return d.events(ctx, txn, nids)
} }
func (d *Database) LatestEventIDs( func (d *Database) LatestEventIDs(
@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs( func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) return d.stateBlockNIDs(ctx, nil, stateNIDs)
}
func (d *Database) stateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
} }
func (d *Database) StateEntries( func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.stateEntries(ctx, nil, stateBlockNIDs)
}
func (d *Database) stateEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs, ctx, txn, stateBlockNIDs,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
} }
lists := make([]types.StateEntryList, 0, len(entries)) lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries { for i, entry := range entries {
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
} }
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
} }
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
} }
func (d *Database) GetCreatorIDForAlias( func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (string, error) { ) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
} }
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
senderMembershipEventNID, senderMembership, isRoomforgotten, err := senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget( d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID, ctx, nil, roomNID, requestSenderUserNID,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
}
func (d *Database) getMembershipEventNIDsForRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership( return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly, ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
) )
} }
return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
} }
func (d *Database) GetInvitesForUser( func (d *Database) GetInvitesForUser(
@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
} }
func (d *Database) Events( func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) { ) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) return d.events(ctx, nil, eventNIDs)
}
func (d *Database) events(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
var roomNIDs map[types.EventNID]types.RoomNID var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -398,7 +471,7 @@ func (d *Database) Events(
} }
fetchNIDList = append(fetchNIDList, n) fetchNIDList = append(fetchNIDList, n)
} }
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater(
return updater, err return updater, err
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo *types.RoomInfo,
) (*LatestEventsUpdater, error) { ) (*RoomUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil { if d.GetRoomUpdaterFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo) return d.GetRoomUpdaterFn(ctx, roomInfo)
} }
txn, err := d.DB.Begin() txn, err := d.DB.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var updater *LatestEventsUpdater var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err return err
}) })
return updater, err return updater, err
@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool, authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}
func (d *Database) storeEvent(
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -472,8 +552,11 @@ func (d *Database) StoreEvent(
redactedEventID string redactedEventID string
err error err error
) )
var txn *sql.Tx
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if updater != nil {
txn = updater.txn
}
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms // TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones. // vs existing ones.
@ -546,42 +629,32 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring // events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well, // that there's a row-level lock on the latest room events (well,
// on Postgres at least). // on Postgres at least).
var roomInfo *types.RoomInfo
var updater *LatestEventsUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like // function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`. // to do writes however then this will need to go inside `Writer.Do`.
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) succeeded := false
if err != nil { if updater == nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) var roomInfo *types.RoomInfo
} roomInfo, err = d.RoomInfo(ctx, event.RoomID())
// Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents if err != nil {
// and EndTransaction in a writer then it's possible for a new write txn to be made between the two return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
// function calls which will then fail with 'database is locked'. This new write txn would HAVE to be
// something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to
// SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases
// as they don't go via InputRoomEvents
err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error {
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return fmt.Errorf("updater.StorePreviousEvents: %w", err)
} }
succeeded := true if roomInfo == nil && len(prevEvents) > 0 {
err = sqlutil.EndTransaction(updater, &succeeded) return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
return err }
}) updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil { if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", err return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
} }
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
} }
return eventNID, roomNID, types.StateAtEvent{ return eventNID, roomNID, types.StateAtEvent{
@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
} }
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
} }
func (d *Database) assignRoomNID( func (d *Database) assignRoomNID(
@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID) eventNIDs = append(eventNIDs, e.EventNID)
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
// return the event requested // return the event requested
for _, e := range entries { for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
} }
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
} }
@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure. // isn't a failure.
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
} }
@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
} }
@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
} }
} }
} }
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil { if err != nil {
eventIDs = map[types.EventNID]string{} eventIDs = map[types.EventNID]string{}
} }
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
} }
@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid stateKeyNIDs[i] = nid
i++ i++
} }
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
} }
// GetServerInRoom returns true if we think a server is in a given room or false otherwise. // GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
} }
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
} }
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx) return d.RoomsTable.SelectRoomIDs(ctx, nil)
} }
// ForgetRoom sets a users room to forgotten // ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil { if err != nil {
return err return err
} }

View File

@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON(
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) { ) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
var rows *sql.Rows
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
} }
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string, ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) { ) (map[string]types.EventStateKeyNID, error) {
iEventStateKeys := make([]interface{}, len(eventStateKeys)) iEventStateKeys := make([]interface{}, len(eventStateKeys))
for k, v := range eventStateKeys { for k, v := range eventStateKeys {
iEventStateKeys[k] = v iEventStateKeys[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
var rows *sql.Rows
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
} else {
rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
} }
func (s *eventStateKeyStatements) BulkSelectEventStateKey( func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs { for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v iEventStateKeyNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
} }
func (s *eventTypeStatements) BulkSelectEventTypeNID( func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string, ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) { ) (map[string]types.EventTypeNID, error) {
/////////////// ///////////////
iEventTypes := make([]interface{}, len(eventTypes)) iEventTypes := make([]interface{}, len(eventTypes))
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sqlutil.TxStmt(txn, selectPrep)
/////////////// ///////////////
rows, err := selectPrep.QueryContext(ctx, iEventTypes...) rows, err := stmt.QueryContext(ctx, iEventTypes...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID( func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID( func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
if err != nil { if err != nil {
return nil, fmt.Errorf("s.db.Prepare: %w", err) return nil, fmt.Errorf("s.db.Prepare: %w", err)
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, params...) rows, err := selectStmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) { ) ([]types.StateAtEvent, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectPrep = sqlutil.TxStmt(txn, selectPrep)
////////////// //////////////
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
/////////////// ///////////////
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs { for k, v := range eventNIDs {
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
/////////////// ///////////////
iEventIDs := make([]interface{}, len(eventIDs)) iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs { for k, v := range eventIDs {
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
} }
func (s *eventStatements) SelectRoomNIDsForEventNIDs( func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) { ) (map[types.EventNID]types.RoomNID, error) {
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr) sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs { for i, v := range eventNIDs {
iEventNIDs[i] = v iEventNIDs[i] = v

View File

@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
} }
func (s *inviteStatements) InsertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
// gather all the event IDs we will retire // gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) { ) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {

View File

@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
} }
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
func (s *membershipStatements) SelectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt var selectStmt *sql.Stmt
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else { } else {
selectStmt = s.selectMembershipsFromRoomStmt selectStmt = s.selectMembershipsFromRoomStmt
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, roomNID) rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} }
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else { } else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID, membership) rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
} }
func (s *membershipStatements) SelectRoomsWithMembership( func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) { ) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil return roomNIDs, nil
} }
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
func (s *membershipStatements) UpdateForgetMembership( func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool, forget bool,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership(
return err return err
} }
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil return found, nil
} }
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
var nid types.RoomNID var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil

View File

@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) { ) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool, ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) { ) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) { ) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
} }
func (s *roomAliasesStatements) SelectAliasesFromRoomID( func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (aliases []string, err error) { ) (aliases []string, err error) {
aliases = []string{} aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
} }
func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string, ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) { ) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }

View File

@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo var info types.RoomInfo
var latestNIDsJSON string var latestNIDsJSON string
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
) )
if err != nil { if err != nil {
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
} }
func (s *roomStatements) SelectRoomVersionsForRoomNIDs( func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID, ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr) sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil return result, nil
} }
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil return roomIDs, nil
} }
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs { for i, v := range roomIDs {
iRoomIDs[i] = v iRoomIDs[i] = v
} }
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...)
} else {
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
} }
func (s *stateBlockStatements) BulkInsertStateData( func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, ctx context.Context, txn *sql.Tx,
txn *sql.Tx,
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData(
if err != nil { if err != nil {
return 0, fmt.Errorf("json.Marshal: %w", err) return 0, fmt.Errorf("json.Marshal: %w", err)
} }
err = s.insertStateDataStmt.QueryRowContext( stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), js, ctx, nids.Hash(), js,
).Scan(&id) ).Scan(&id)
return return
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) { ) ([][]types.EventNID, error) {
intfs := make([]interface{}, len(stateBlockNIDs)) intfs := make([]interface{}, len(stateBlockNIDs))
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, intfs...) rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
} }
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID, ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) { ) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs)) nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs { for k, v := range stateNIDs {
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, nids...) rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil { if err != nil {

View File

@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),
EventsTable: events, EventsTable: events,
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, EventJSONTable: eventJSON,
RoomsTable: rooms, RoomsTable: rooms,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents, PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,
MembershipTable: membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions, RedactionsTable: redactions,
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, GetRoomUpdaterFn: d.GetRoomUpdater,
} }
return nil return nil
} }
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return false return false
} }
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo *types.RoomInfo,
) (*shared.LatestEventsUpdater, error) { ) (*shared.RoomUpdater, error) {
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have // TODO: Do not use transactions. We should be holding open this transaction but we cannot have
// multiple write transactions on sqlite. The code will perform additional // multiple write transactions on sqlite. The code will perform additional
// write transactions independent of this one which will consistently cause // write transactions independent of this one which will consistently cause
// 'database is locked' errors. As sqlite doesn't support multi-process on the // 'database is locked' errors. As sqlite doesn't support multi-process on the
// same DB anyway, and we only execute updates sequentially, the only worries // same DB anyway, and we only execute updates sequentially, the only worries
// are for rolling back when things go wrong. (atomicity) // are for rolling back when things go wrong. (atomicity)
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
} }
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(

View File

@ -18,20 +18,20 @@ type EventJSONPair struct {
type EventJSON interface { type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
} }
type EventTypes interface { type EventTypes interface {
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
} }
type EventStateKeys interface { type EventStateKeys interface {
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
} }
type Events interface { type Events interface {
@ -42,12 +42,12 @@ type Events interface {
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
@ -55,12 +55,12 @@ type Events interface {
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
// BulkSelectEventID returns a map from numeric event ID to string event ID. // BulkSelectEventID returns a map from numeric event ID to string event ID.
BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
} }
type Rooms interface { type Rooms interface {
@ -69,29 +69,29 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
} }
type StateSnapshot interface { type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
} }
type StateBlock interface { type StateBlock interface {
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
} }
type RoomAliases interface { type RoomAliases interface {
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
} }
@ -106,7 +106,7 @@ type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
} }
type MembershipState int64 type MembershipState int64
@ -121,24 +121,24 @@ const (
type Membership interface { type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined. // counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
} }
type Published interface { type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
} }
type RedactionInfo struct { type RedactionInfo struct {