mirror of
https://github.com/1f349/dendrite.git
synced 2024-12-22 08:14:13 +00:00
Transaction writer changes, move roomserver writers (#1285)
* Updated TransactionWriters, moved locks in roomserver, various other tweaks * Fix redaction deadlocks * Fix lint issue * Rename SQLiteTransactionWriter to ExclusiveTransactionWriter * Fix us not sending transactions through in latest events updater
This commit is contained in:
parent
775b04d776
commit
b24747b305
@ -67,7 +67,7 @@ const (
|
||||
|
||||
type eventsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
||||
countEventsByApplicationServiceIDStmt *sql.Stmt
|
||||
insertEventStmt *sql.Stmt
|
||||
|
@ -38,7 +38,7 @@ const selectTxnIDSQL = `
|
||||
|
||||
type txnStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
selectTxnIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" +
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||
|
@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" +
|
||||
|
||||
type blacklistStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertBlacklistStmt *sql.Stmt
|
||||
selectBlacklistStmt *sql.Stmt
|
||||
deleteBlacklistStmt *sql.Stmt
|
||||
@ -50,8 +49,7 @@ type blacklistStatements struct {
|
||||
|
||||
func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||
s = &blacklistStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(blacklistSchema)
|
||||
if err != nil {
|
||||
@ -75,11 +73,9 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||
func (s *blacklistStatements) InsertBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||
@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist(
|
||||
func (s *blacklistStatements) DeleteBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ const deleteBlacklistSQL = "" +
|
||||
|
||||
type blacklistStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertBlacklistStmt *sql.Stmt
|
||||
selectBlacklistStmt *sql.Stmt
|
||||
deleteBlacklistStmt *sql.Stmt
|
||||
|
@ -65,7 +65,7 @@ const selectJoinedHostsForRoomsSQL = "" +
|
||||
|
||||
type joinedHostsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertJoinedHostsStmt *sql.Stmt
|
||||
deleteJoinedHostsStmt *sql.Stmt
|
||||
selectJoinedHostsStmt *sql.Stmt
|
||||
|
@ -64,7 +64,7 @@ const selectQueueServerNamesSQL = "" +
|
||||
|
||||
type queueEDUsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||
|
@ -50,7 +50,7 @@ const selectJSONSQL = "" +
|
||||
|
||||
type queueJSONStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertJSONStmt *sql.Stmt
|
||||
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
|
@ -71,7 +71,7 @@ const selectQueuePDUsServerNamesSQL = "" +
|
||||
|
||||
type queuePDUsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertQueuePDUStmt *sql.Stmt
|
||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||
selectQueuePDUsByTransactionStmt *sql.Stmt
|
||||
|
@ -44,7 +44,7 @@ const updateRoomSQL = "" +
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertRoomStmt *sql.Stmt
|
||||
selectRoomForUpdateStmt *sql.Stmt
|
||||
updateRoomStmt *sql.Stmt
|
||||
|
@ -19,8 +19,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// ErrUserExists is returned if a username already exists in the database.
|
||||
@ -52,7 +50,7 @@ func EndTransaction(txn Transaction, succeeded *bool) error {
|
||||
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err)
|
||||
}
|
||||
succeeded := false
|
||||
defer func() {
|
||||
@ -106,69 +104,6 @@ func SQLiteDriverName() string {
|
||||
return "sqlite3"
|
||||
}
|
||||
|
||||
// TransactionWriter allows queuing database writes so that you don't
|
||||
// contend on database locks in, e.g. SQLite. Only one task will run
|
||||
// at a time on a given TransactionWriter.
|
||||
type TransactionWriter struct {
|
||||
running atomic.Bool
|
||||
todo chan transactionWriterTask
|
||||
}
|
||||
|
||||
func NewTransactionWriter() *TransactionWriter {
|
||||
return &TransactionWriter{
|
||||
todo: make(chan transactionWriterTask),
|
||||
}
|
||||
}
|
||||
|
||||
// transactionWriterTask represents a specific task.
|
||||
type transactionWriterTask struct {
|
||||
db *sql.DB
|
||||
txn *sql.Tx
|
||||
f func(txn *sql.Tx) error
|
||||
wait chan error
|
||||
}
|
||||
|
||||
// Do queues a task to be run by a TransactionWriter. The function
|
||||
// provided will be ran within a transaction as supplied by the
|
||||
// txn parameter if one is supplied, and if not, will take out a
|
||||
// new transaction from the database supplied in the database
|
||||
// parameter. Either way, this will block until the task is done.
|
||||
func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||
if w.todo == nil {
|
||||
return errors.New("not initialised")
|
||||
}
|
||||
if !w.running.Load() {
|
||||
go w.run()
|
||||
}
|
||||
task := transactionWriterTask{
|
||||
db: db,
|
||||
txn: txn,
|
||||
f: f,
|
||||
wait: make(chan error, 1),
|
||||
}
|
||||
w.todo <- task
|
||||
return <-task.wait
|
||||
}
|
||||
|
||||
// run processes the tasks for a given transaction writer. Only one
|
||||
// of these goroutines will run at a time. A transaction will be
|
||||
// opened using the database object from the task and then this will
|
||||
// be passed as a parameter to the task function.
|
||||
func (w *TransactionWriter) run() {
|
||||
if !w.running.CAS(false, true) {
|
||||
return
|
||||
}
|
||||
defer w.running.Store(false)
|
||||
for task := range w.todo {
|
||||
if task.txn != nil {
|
||||
task.wait <- task.f(task.txn)
|
||||
} else if task.db != nil {
|
||||
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
||||
return task.f(txn)
|
||||
})
|
||||
} else {
|
||||
panic("expected database or transaction but got neither")
|
||||
}
|
||||
close(task.wait)
|
||||
}
|
||||
type TransactionWriter interface {
|
||||
Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
|
||||
}
|
||||
|
22
internal/sqlutil/writer_dummy.go
Normal file
22
internal/sqlutil/writer_dummy.go
Normal file
@ -0,0 +1,22 @@
|
||||
package sqlutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type DummyTransactionWriter struct {
|
||||
}
|
||||
|
||||
func NewDummyTransactionWriter() TransactionWriter {
|
||||
return &DummyTransactionWriter{}
|
||||
}
|
||||
|
||||
func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||
if txn == nil {
|
||||
return WithTransaction(db, func(txn *sql.Tx) error {
|
||||
return f(txn)
|
||||
})
|
||||
} else {
|
||||
return f(txn)
|
||||
}
|
||||
}
|
75
internal/sqlutil/writer_exclusive.go
Normal file
75
internal/sqlutil/writer_exclusive.go
Normal file
@ -0,0 +1,75 @@
|
||||
package sqlutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// ExclusiveTransactionWriter allows queuing database writes so that you don't
|
||||
// contend on database locks in, e.g. SQLite. Only one task will run
|
||||
// at a time on a given ExclusiveTransactionWriter.
|
||||
type ExclusiveTransactionWriter struct {
|
||||
running atomic.Bool
|
||||
todo chan transactionWriterTask
|
||||
}
|
||||
|
||||
func NewTransactionWriter() TransactionWriter {
|
||||
return &ExclusiveTransactionWriter{
|
||||
todo: make(chan transactionWriterTask),
|
||||
}
|
||||
}
|
||||
|
||||
// transactionWriterTask represents a specific task.
|
||||
type transactionWriterTask struct {
|
||||
db *sql.DB
|
||||
txn *sql.Tx
|
||||
f func(txn *sql.Tx) error
|
||||
wait chan error
|
||||
}
|
||||
|
||||
// Do queues a task to be run by a TransactionWriter. The function
|
||||
// provided will be ran within a transaction as supplied by the
|
||||
// txn parameter if one is supplied, and if not, will take out a
|
||||
// new transaction from the database supplied in the database
|
||||
// parameter. Either way, this will block until the task is done.
|
||||
func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||
if w.todo == nil {
|
||||
return errors.New("not initialised")
|
||||
}
|
||||
if !w.running.Load() {
|
||||
go w.run()
|
||||
}
|
||||
task := transactionWriterTask{
|
||||
db: db,
|
||||
txn: txn,
|
||||
f: f,
|
||||
wait: make(chan error, 1),
|
||||
}
|
||||
w.todo <- task
|
||||
return <-task.wait
|
||||
}
|
||||
|
||||
// run processes the tasks for a given transaction writer. Only one
|
||||
// of these goroutines will run at a time. A transaction will be
|
||||
// opened using the database object from the task and then this will
|
||||
// be passed as a parameter to the task function.
|
||||
func (w *ExclusiveTransactionWriter) run() {
|
||||
if !w.running.CAS(false, true) {
|
||||
return
|
||||
}
|
||||
defer w.running.Store(false)
|
||||
for task := range w.todo {
|
||||
if task.txn != nil {
|
||||
task.wait <- task.f(task.txn)
|
||||
} else if task.db != nil {
|
||||
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
||||
return task.f(txn)
|
||||
})
|
||||
} else {
|
||||
panic("expected database or transaction but got neither")
|
||||
}
|
||||
close(task.wait)
|
||||
}
|
||||
}
|
@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" +
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
|
@ -52,7 +52,7 @@ const selectKeyChangesSQL = "" +
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt *sql.Stmt
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
|
@ -62,7 +62,7 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
|
||||
|
||||
type mediaStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertMediaStmt *sql.Stmt
|
||||
selectMediaStmt *sql.Stmt
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||
) (err error) {
|
||||
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
||||
if err != nil {
|
||||
return
|
||||
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
|
||||
}
|
||||
succeeded := false
|
||||
defer func() {
|
||||
@ -79,7 +79,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||
}
|
||||
|
||||
if err = u.doUpdateLatestEvents(); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
|
||||
}
|
||||
|
||||
succeeded = true
|
||||
@ -137,7 +137,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||
// don't need to do anything, as we've handled it already.
|
||||
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
|
||||
} else if hasBeenSent {
|
||||
return nil
|
||||
}
|
||||
@ -145,7 +145,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||
// Update the roomserver_previous_events table with references. This
|
||||
// is effectively tracking the structure of the DAG.
|
||||
if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.updater.StorePreviousEvents: %w", err)
|
||||
}
|
||||
|
||||
// Get the event reference for our new event. This will be used when
|
||||
@ -156,7 +156,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||
// in the room. If it is then it isn't a latest event.
|
||||
alreadyReferenced, err := u.updater.IsReferenced(eventReference)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.updater.IsReferenced: %w", err)
|
||||
}
|
||||
|
||||
// Work out what the latest events are.
|
||||
@ -173,19 +173,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||
// Now that we know what the latest events are, it's time to get the
|
||||
// latest state.
|
||||
if err = u.latestState(); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.latestState: %w", err)
|
||||
}
|
||||
|
||||
// If we need to generate any output events then here's where we do it.
|
||||
// TODO: Move this!
|
||||
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.api.updateMemberships: %w", err)
|
||||
}
|
||||
|
||||
update, err := u.makeOutputNewRoomEvent()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
|
||||
}
|
||||
updates = append(updates, *update)
|
||||
|
||||
@ -198,14 +198,18 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
||||
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
||||
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
||||
}
|
||||
|
||||
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
|
||||
}
|
||||
|
||||
return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID)
|
||||
if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil {
|
||||
return fmt.Errorf("u.updater.MarkEventAsSent: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *latestEventsUpdater) latestState() error {
|
||||
@ -225,7 +229,7 @@ func (u *latestEventsUpdater) latestState() error {
|
||||
u.ctx, u.roomNID, latestStateAtEvents,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
|
||||
}
|
||||
|
||||
// If we are overwriting the state then we should make sure that we
|
||||
@ -244,7 +248,7 @@ func (u *latestEventsUpdater) latestState() error {
|
||||
u.ctx, u.oldStateNID, u.newStateNID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
|
||||
}
|
||||
|
||||
// Also work out the state before the event removes and the event
|
||||
@ -252,7 +256,11 @@ func (u *latestEventsUpdater) latestState() error {
|
||||
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
|
||||
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateLatest(
|
||||
|
@ -558,7 +558,11 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||
// 2) There weren't any prev_events for this event so the state is
|
||||
// empty.
|
||||
metrics.algorithm = "empty_state"
|
||||
return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil))
|
||||
stateNID, err := v.db.AddState(ctx, roomNID, nil, nil)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||
}
|
||||
return metrics.stop(stateNID, err)
|
||||
}
|
||||
|
||||
if len(prevStates) == 1 {
|
||||
@ -578,22 +582,30 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||
)
|
||||
if err != nil {
|
||||
metrics.algorithm = "_load_state_blocks"
|
||||
return metrics.stop(0, err)
|
||||
return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err))
|
||||
}
|
||||
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
|
||||
if len(stateBlockNIDs) < maxStateBlockNIDs {
|
||||
// 4) The number of state data blocks is small enough that we can just
|
||||
// add the state event as a block of size one to the end of the blocks.
|
||||
metrics.algorithm = "single_delta"
|
||||
return metrics.stop(v.db.AddState(
|
||||
stateNID, err := v.db.AddState(
|
||||
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
||||
))
|
||||
)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||
}
|
||||
return metrics.stop(stateNID, err)
|
||||
}
|
||||
// If there are too many deltas then we need to calculate the full state
|
||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||
}
|
||||
|
||||
return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
|
||||
stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
|
||||
}
|
||||
return stateNID, nil
|
||||
}
|
||||
|
||||
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
||||
|
@ -98,6 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewDummyTransactionWriter(),
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
|
@ -3,6 +3,7 @@ package shared
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
@ -65,12 +66,14 @@ func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
||||
|
||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
||||
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||
for _, ref := range previousEventReferences {
|
||||
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||
return err
|
||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||
for _, ref := range previousEventReferences {
|
||||
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||
@ -82,7 +85,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
|
||||
}
|
||||
|
||||
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
||||
@ -94,7 +97,12 @@ func (u *LatestEventsUpdater) SetLatestEvents(
|
||||
for i := range latest {
|
||||
eventNIDs[i] = latest[i].EventNID
|
||||
}
|
||||
return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
|
||||
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||
@ -104,7 +112,9 @@ func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, e
|
||||
|
||||
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
|
||||
})
|
||||
}
|
||||
|
||||
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||
|
@ -3,6 +3,7 @@ package shared
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
@ -41,9 +42,14 @@ func (d *Database) membershipUpdaterTxn(
|
||||
targetUserNID types.EventStateKeyNID,
|
||||
targetLocal bool,
|
||||
) (*MembershipUpdater, error) {
|
||||
|
||||
if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
||||
return nil, err
|
||||
err := d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||
if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
||||
return fmt.Errorf("d.MembershipTable.InsertMembership: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("u.d.Writer.Do: %w", err)
|
||||
}
|
||||
|
||||
membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
||||
@ -75,19 +81,19 @@ func (u *MembershipUpdater) IsLeave() bool {
|
||||
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
|
||||
}
|
||||
inserted, err := u.d.InvitesTable.InsertInviteEvent(
|
||||
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
|
||||
}
|
||||
if u.membership != tables.MembershipStateInvite {
|
||||
if err = u.d.MembershipTable.UpdateMembership(
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
|
||||
); err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
|
||||
}
|
||||
}
|
||||
return inserted, nil
|
||||
@ -99,7 +105,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
|
||||
}
|
||||
|
||||
// If this is a join event update, there is no invite to update
|
||||
@ -108,14 +114,14 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Look up the NID of the new join event
|
||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
|
||||
}
|
||||
|
||||
if u.membership != tables.MembershipStateJoin || isUpdate {
|
||||
@ -123,7 +129,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||
tables.MembershipStateJoin, nIDs[eventID],
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,19 +140,19 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
|
||||
}
|
||||
inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired(
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
|
||||
}
|
||||
|
||||
// Look up the NID of the new leave event
|
||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
|
||||
}
|
||||
|
||||
if u.membership != tables.MembershipStateLeaveOrBan {
|
||||
@ -154,7 +160,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||
tables.MembershipStateLeaveOrBan, nIDs[eventID],
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
|
||||
}
|
||||
}
|
||||
return inviteEventIDs, nil
|
||||
|
@ -27,6 +27,7 @@ const redactionsArePermanent = false
|
||||
|
||||
type Database struct {
|
||||
DB *sql.DB
|
||||
Writer sqlutil.TransactionWriter
|
||||
EventsTable tables.Events
|
||||
EventJSONTable tables.EventJSON
|
||||
EventTypesTable tables.EventTypes
|
||||
@ -83,20 +84,23 @@ func (d *Database) AddState(
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
state []types.StateEntry,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if len(state) > 0 {
|
||||
var stateBlockNID types.StateBlockNID
|
||||
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err)
|
||||
}
|
||||
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
||||
}
|
||||
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("d.Writer.Do: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -110,7 +114,9 @@ func (d *Database) EventNIDs(
|
||||
func (d *Database) SetState(
|
||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||
) error {
|
||||
return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID)
|
||||
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||
return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) StateAtEventIDs(
|
||||
@ -221,7 +227,9 @@ func (d *Database) GetRoomVersionForRoomNID(
|
||||
}
|
||||
|
||||
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
||||
return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
|
||||
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||
return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
|
||||
@ -239,15 +247,21 @@ func (d *Database) GetCreatorIDForAlias(
|
||||
}
|
||||
|
||||
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
||||
return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
|
||||
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||
return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetMembership(
|
||||
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
||||
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
||||
requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
|
||||
var requestSenderUserNID types.EventStateKeyNID
|
||||
err = d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||
requestSenderUserNID, err = d.assignStateKeyNID(ctx, nil, requestSenderUserID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err)
|
||||
}
|
||||
|
||||
senderMembershipEventNID, senderMembership, err :=
|
||||
@ -350,6 +364,7 @@ func (d *Database) GetLatestEventsForUpdate(
|
||||
return NewLatestEventsUpdater(ctx, d, txn, roomNID)
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (d *Database) StoreEvent(
|
||||
ctx context.Context, event gomatrixserverlib.Event,
|
||||
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
||||
@ -365,10 +380,10 @@ func (d *Database) StoreEvent(
|
||||
err error
|
||||
)
|
||||
|
||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if txnAndSessionID != nil {
|
||||
if err = d.TransactionsTable.InsertTransaction(
|
||||
ctx, txn, txnAndSessionID.TransactionID,
|
||||
ctx, nil, txnAndSessionID.TransactionID,
|
||||
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
||||
); err != nil {
|
||||
return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err)
|
||||
@ -433,7 +448,7 @@ func (d *Database) StoreEvent(
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, types.StateAtEvent{}, nil, "", err
|
||||
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
|
||||
}
|
||||
|
||||
return roomNID, types.StateAtEvent{
|
||||
@ -449,7 +464,9 @@ func (d *Database) StoreEvent(
|
||||
}
|
||||
|
||||
func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error {
|
||||
return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish)
|
||||
return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
|
||||
return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
|
||||
|
@ -49,15 +49,13 @@ const bulkSelectEventJSONSQL = `
|
||||
|
||||
type eventJSONStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventJSONStmt *sql.Stmt
|
||||
bulkSelectEventJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) {
|
||||
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||
s := &eventJSONStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventJSONSchema)
|
||||
if err != nil {
|
||||
@ -72,10 +70,8 @@ func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab
|
||||
func (s *eventJSONStatements) InsertEventJSON(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
return err
|
||||
})
|
||||
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||
|
@ -64,17 +64,15 @@ const bulkSelectEventStateKeyNIDSQL = `
|
||||
|
||||
type eventStateKeyStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventStateKeyNIDStmt *sql.Stmt
|
||||
selectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) {
|
||||
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
|
||||
s := &eventStateKeyStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventStateKeysSchema)
|
||||
if err != nil {
|
||||
@ -91,19 +89,15 @@ func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter)
|
||||
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventStateKeyNID, err = res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, eventStateKey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
eventStateKeyNID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,7 @@ package sqlite3
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
@ -78,17 +79,15 @@ const bulkSelectEventTypeNIDSQL = `
|
||||
|
||||
type eventTypeStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventTypeNIDStmt *sql.Stmt
|
||||
insertEventTypeNIDResultStmt *sql.Stmt
|
||||
selectEventTypeNIDStmt *sql.Stmt
|
||||
bulkSelectEventTypeNIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) {
|
||||
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
|
||||
s := &eventTypeStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventTypesSchema)
|
||||
if err != nil {
|
||||
@ -104,18 +103,18 @@ func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) InsertEventTypeNID(
|
||||
ctx context.Context, tx *sql.Tx, eventType string,
|
||||
ctx context.Context, txn *sql.Tx, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
|
||||
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
|
||||
_, err := insertStmt.ExecContext(ctx, eventType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
|
||||
})
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt)
|
||||
resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt)
|
||||
_, err := insertStmt.ExecContext(ctx, eventType)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
}
|
||||
if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
|
||||
return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err)
|
||||
}
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,6 @@ const selectRoomNIDForEventNIDSQL = "" +
|
||||
|
||||
type eventStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventStmt *sql.Stmt
|
||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||
@ -115,10 +114,9 @@ type eventStatements struct {
|
||||
selectRoomNIDForEventNIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) {
|
||||
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
||||
s := &eventStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(eventsSchema)
|
||||
if err != nil {
|
||||
@ -155,22 +153,19 @@ func (s *eventStatements) InsertEvent(
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
// attempt to insert: the last_row_id is the event NID
|
||||
var eventNID int64
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
result, err := insertStmt.ExecContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modified, err := result.RowsAffected()
|
||||
if modified == 0 && err == nil {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
eventNID, err = result.LastInsertId()
|
||||
return err
|
||||
})
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
result, err := insertStmt.ExecContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
modified, err := result.RowsAffected()
|
||||
if modified == 0 && err == nil {
|
||||
return 0, 0, sql.ErrNoRows
|
||||
}
|
||||
eventNID, err = result.LastInsertId()
|
||||
return types.EventNID(eventNID), 0, err
|
||||
}
|
||||
|
||||
@ -286,11 +281,8 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||
func (s *eventStatements) UpdateEventState(
|
||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||
) error {
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
|
||||
_, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||
return err
|
||||
})
|
||||
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEventSentToOutput(
|
||||
@ -302,11 +294,9 @@ func (s *eventStatements) SelectEventSentToOutput(
|
||||
}
|
||||
|
||||
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
||||
return err
|
||||
})
|
||||
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventStatements) SelectEventID(
|
||||
@ -334,7 +324,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
||||
|
||||
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err)
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
|
||||
results := make([]types.StateAtEventAndReference, len(eventNIDs))
|
||||
@ -481,7 +471,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
||||
}
|
||||
err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
@ -64,17 +64,15 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
|
||||
|
||||
type inviteStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteActiveForUserInRoomStmt *sql.Stmt
|
||||
updateInviteRetiredStmt *sql.Stmt
|
||||
selectInvitesAboutToRetireStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) {
|
||||
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
|
||||
s := &inviteStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(inviteSchema)
|
||||
if err != nil {
|
||||
@ -96,20 +94,17 @@ func (s *inviteStatements) InsertInviteEvent(
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
var count int64
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
result, err := stmt.ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
count, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||
result, err := stmt.ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
count, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count != 0, err
|
||||
}
|
||||
|
||||
@ -117,26 +112,23 @@ func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventIDs []string, err error) {
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return err
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var inviteEventID string
|
||||
if err = rows.Scan(&inviteEventID); err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var inviteEventID string
|
||||
if err = rows.Scan(&inviteEventID); err != nil {
|
||||
return err
|
||||
}
|
||||
eventIDs = append(eventIDs, inviteEventID)
|
||||
}
|
||||
// now retire the invites
|
||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
return err
|
||||
})
|
||||
eventIDs = append(eventIDs, inviteEventID)
|
||||
}
|
||||
// now retire the invites
|
||||
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,6 @@ const updateMembershipSQL = "" +
|
||||
|
||||
type membershipStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipForUpdateStmt *sql.Stmt
|
||||
selectMembershipFromRoomAndTargetStmt *sql.Stmt
|
||||
@ -88,10 +87,9 @@ type membershipStatements struct {
|
||||
updateMembershipStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) {
|
||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
s := &membershipStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(membershipSchema)
|
||||
if err != nil {
|
||||
@ -115,11 +113,9 @@ func (s *membershipStatements) InsertMembership(
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
localTarget bool,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||
@ -201,11 +197,9 @@ func (s *membershipStatements) UpdateMembership(
|
||||
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@ -54,15 +54,13 @@ const selectPreviousEventExistsSQL = `
|
||||
|
||||
type previousEventStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertPreviousEventStmt *sql.Stmt
|
||||
selectPreviousEventExistsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) {
|
||||
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
|
||||
s := &previousEventStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(previousEventSchema)
|
||||
if err != nil {
|
||||
@ -82,13 +80,11 @@ func (s *previousEventStatements) InsertPreviousEvent(
|
||||
previousEventReferenceSHA256 []byte,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||
)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the event reference exists
|
||||
|
@ -19,7 +19,6 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
@ -45,16 +44,14 @@ const selectPublishedSQL = "" +
|
||||
|
||||
type publishedStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
upsertPublishedStmt *sql.Stmt
|
||||
selectAllPublishedStmt *sql.Stmt
|
||||
selectPublishedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) {
|
||||
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
|
||||
s := &publishedStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(publishedSchema)
|
||||
if err != nil {
|
||||
@ -69,12 +66,9 @@ func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab
|
||||
|
||||
func (s *publishedStatements) UpsertRoomPublished(
|
||||
ctx context.Context, roomID string, published bool,
|
||||
) (err error) {
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, published)
|
||||
return err
|
||||
})
|
||||
) error {
|
||||
_, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
|
@ -53,17 +53,15 @@ const markRedactionValidatedSQL = "" +
|
||||
|
||||
type redactionStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertRedactionStmt *sql.Stmt
|
||||
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
|
||||
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
|
||||
markRedactionValidatedStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) {
|
||||
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
|
||||
s := &redactionStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(redactionsSchema)
|
||||
if err != nil {
|
||||
@ -81,11 +79,9 @@ func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta
|
||||
func (s *redactionStatements) InsertRedaction(
|
||||
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
|
||||
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
|
||||
@ -121,9 +117,7 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
|
||||
func (s *redactionStatements) MarkRedactionValidated(
|
||||
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
|
||||
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
|
||||
return err
|
||||
}
|
||||
|
@ -20,7 +20,6 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
)
|
||||
@ -57,7 +56,6 @@ const deleteRoomAliasSQL = `
|
||||
|
||||
type roomAliasesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertRoomAliasStmt *sql.Stmt
|
||||
selectRoomIDFromAliasStmt *sql.Stmt
|
||||
selectAliasesFromRoomIDStmt *sql.Stmt
|
||||
@ -65,10 +63,9 @@ type roomAliasesStatements struct {
|
||||
deleteRoomAliasStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) {
|
||||
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(roomAliasesSchema)
|
||||
if err != nil {
|
||||
@ -85,12 +82,9 @@ func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (t
|
||||
|
||||
func (s *roomAliasesStatements) InsertRoomAlias(
|
||||
ctx context.Context, alias string, roomID string, creatorUserID string,
|
||||
) (err error) {
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
|
||||
_, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||
return err
|
||||
})
|
||||
) error {
|
||||
_, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
@ -138,10 +132,7 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
|
||||
func (s *roomAliasesStatements) DeleteRoomAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (err error) {
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
|
||||
_, err := stmt.ExecContext(ctx, alias)
|
||||
return err
|
||||
})
|
||||
) error {
|
||||
_, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
||||
return err
|
||||
}
|
||||
|
@ -66,7 +66,6 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
selectLatestEventNIDsStmt *sql.Stmt
|
||||
@ -76,10 +75,9 @@ type roomStatements struct {
|
||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) {
|
||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
s := &roomStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(roomsSchema)
|
||||
if err != nil {
|
||||
@ -100,20 +98,14 @@ func (s *roomStatements) InsertRoomNID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (roomNID types.RoomNID, err error) {
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
}
|
||||
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.SelectRoomNID: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
_, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
|
||||
if err != nil {
|
||||
return types.RoomNID(0), err
|
||||
return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
|
||||
}
|
||||
roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -170,17 +162,15 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||
lastEventSentNID types.EventNID,
|
||||
stateSnapshotNID types.StateSnapshotNID,
|
||||
) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
roomNID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
roomNID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionForRoomID(
|
||||
|
@ -74,17 +74,15 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||
|
||||
type stateBlockStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertStateDataStmt *sql.Stmt
|
||||
selectNextStateBlockNIDStmt *sql.Stmt
|
||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) {
|
||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
if err != nil {
|
||||
@ -107,25 +105,22 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
||||
return 0, nil
|
||||
}
|
||||
var stateBlockNID types.StateBlockNID
|
||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
_, err = txn.Stmt(s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return stateBlockNID, err
|
||||
}
|
||||
|
||||
|
@ -50,15 +50,13 @@ const bulkSelectStateBlockNIDsSQL = "" +
|
||||
|
||||
type stateSnapshotStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertStateStmt *sql.Stmt
|
||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) {
|
||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
if err != nil {
|
||||
@ -78,19 +76,16 @@ func (s *stateSnapshotStatements) InsertState(
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastRowID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(lastRowID)
|
||||
return nil
|
||||
})
|
||||
insertStmt := txn.Stmt(s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
lastRowID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(lastRowID)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,7 @@ type Database struct {
|
||||
invites tables.Invites
|
||||
membership tables.Membership
|
||||
db *sql.DB
|
||||
writer sqlutil.TransactionWriter
|
||||
}
|
||||
|
||||
// Open a sqlite database.
|
||||
@ -51,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer := sqlutil.NewTransactionWriter()
|
||||
d.writer = sqlutil.NewTransactionWriter()
|
||||
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
||||
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
||||
|
||||
@ -61,64 +62,65 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||
// which it will never obtain.
|
||||
d.db.SetMaxOpenConns(20)
|
||||
|
||||
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer)
|
||||
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer)
|
||||
d.eventTypes, err = NewSqliteEventTypesTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer)
|
||||
d.eventJSON, err = NewSqliteEventJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.events, err = NewSqliteEventsTable(d.db, writer)
|
||||
d.events, err = NewSqliteEventsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.rooms, err = NewSqliteRoomsTable(d.db, writer)
|
||||
d.rooms, err = NewSqliteRoomsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.transactions, err = NewSqliteTransactionsTable(d.db, writer)
|
||||
d.transactions, err = NewSqliteTransactionsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateBlock, err := NewSqliteStateBlockTable(d.db, writer)
|
||||
stateBlock, err := NewSqliteStateBlockTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer)
|
||||
stateSnapshot, err := NewSqliteStateSnapshotTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer)
|
||||
d.prevEvents, err = NewSqlitePrevEventsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer)
|
||||
roomAliases, err := NewSqliteRoomAliasesTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.invites, err = NewSqliteInvitesTable(d.db, writer)
|
||||
d.invites, err = NewSqliteInvitesTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.membership, err = NewSqliteMembershipTable(d.db, writer)
|
||||
d.membership, err = NewSqliteMembershipTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
published, err := NewSqlitePublishedTable(d.db, writer)
|
||||
published, err := NewSqlitePublishedTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redactions, err := NewSqliteRedactionsTable(d.db, writer)
|
||||
redactions, err := NewSqliteRedactionsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
Writer: sqlutil.NewTransactionWriter(),
|
||||
EventsTable: d.events,
|
||||
EventTypesTable: d.eventTypes,
|
||||
EventStateKeysTable: d.eventStateKeys,
|
||||
|
@ -45,15 +45,13 @@ const selectTransactionEventIDSQL = `
|
||||
|
||||
type transactionStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertTransactionStmt *sql.Stmt
|
||||
selectTransactionEventIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) {
|
||||
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
|
||||
s := &transactionStatements{
|
||||
db: db,
|
||||
writer: writer,
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(transactionsSchema)
|
||||
if err != nil {
|
||||
@ -72,14 +70,12 @@ func (s *transactionStatements) InsertTransaction(
|
||||
sessionID int64,
|
||||
userID string,
|
||||
eventID string,
|
||||
) (err error) {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
return err
|
||||
})
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *transactionStatements) SelectTransactionEventID(
|
||||
|
@ -63,7 +63,7 @@ const upsertServerKeysSQL = "" +
|
||||
|
||||
type serverKeyStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
bulkSelectServerKeysStmt *sql.Stmt
|
||||
upsertServerKeysStmt *sql.Stmt
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ type Database struct {
|
||||
BackwardExtremities tables.BackwardsExtremities
|
||||
SendToDevice tables.SendToDevice
|
||||
Filter tables.Filter
|
||||
SendToDeviceWriter *sqlutil.TransactionWriter
|
||||
SendToDeviceWriter sqlutil.TransactionWriter
|
||||
EDUCache *cache.EDUCache
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ const selectMaxAccountDataIDSQL = "" +
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectMaxAccountDataIDStmt *sql.Stmt
|
||||
|
@ -49,7 +49,7 @@ const deleteBackwardExtremitySQL = "" +
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
|
@ -85,7 +85,7 @@ const selectEventsWithEventIDsSQL = "" +
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
|
@ -52,7 +52,7 @@ const insertFilterSQL = "" +
|
||||
|
||||
type filterStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
|
@ -59,7 +59,7 @@ const selectMaxInviteIDSQL = "" +
|
||||
|
||||
type inviteEventsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteEventsInRangeStmt *sql.Stmt
|
||||
|
@ -105,7 +105,7 @@ const selectStateInRangeSQL = "" +
|
||||
|
||||
type outputRoomEventsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
|
@ -67,7 +67,7 @@ const selectMaxPositionInTopologySQL = "" +
|
||||
|
||||
type outputRoomEventsTopologyStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertEventInTopologyStmt *sql.Stmt
|
||||
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
|
@ -73,7 +73,7 @@ const deleteSendToDeviceMessagesSQL = `
|
||||
|
||||
type sendToDeviceStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertSendToDeviceMessageStmt *sql.Stmt
|
||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||
countSendToDeviceMessagesStmt *sql.Stmt
|
||||
|
@ -28,7 +28,7 @@ const selectStreamIDStmt = "" +
|
||||
|
||||
type streamIDStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
increaseStreamIDStmt *sql.Stmt
|
||||
selectStreamIDStmt *sql.Stmt
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ const selectAccountDataByTypeSQL = "" +
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectAccountDataStmt *sql.Stmt
|
||||
selectAccountDataByTypeStmt *sql.Stmt
|
||||
|
@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertAccountStmt *sql.Stmt
|
||||
selectAccountByLocalpartStmt *sql.Stmt
|
||||
selectPasswordHashStmt *sql.Stmt
|
||||
|
@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertProfileStmt *sql.Stmt
|
||||
selectProfileByLocalpartStmt *sql.Stmt
|
||||
setAvatarURLStmt *sql.Stmt
|
||||
|
@ -54,7 +54,7 @@ const deleteThreePIDSQL = "" +
|
||||
|
||||
type threepidStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||
selectThreePIDsForLocalpartStmt *sql.Stmt
|
||||
insertThreePIDStmt *sql.Stmt
|
||||
|
@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
writer sqlutil.TransactionWriter
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
|
Loading…
Reference in New Issue
Block a user