diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 6368675b..64795d02 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -49,13 +49,16 @@ const bulkSelectEventJSONSQL = ` type eventJSONStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt } func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { - s := &eventJSONStatements{} - s.db = db + s := &eventJSONStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(eventJSONSchema) if err != nil { return nil, err @@ -69,8 +72,10 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + return err + }) } func (s *eventJSONStatements) BulkSelectEventJSON( diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index cbea8428..3e9f2e61 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -64,6 +64,7 @@ const bulkSelectEventStateKeyNIDSQL = ` type eventStateKeyStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt @@ -71,8 +72,10 @@ type eventStateKeyStatements struct { } func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { - s := &eventStateKeyStatements{} - s.db = db + s := &eventStateKeyStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(eventStateKeysSchema) if err != nil { return nil, err @@ -89,12 +92,18 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - var err error - var res sql.Result - insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) - if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) + res, err := insertStmt.ExecContext(ctx, eventStateKey) + if err != nil { + return err + } eventStateKeyNID, err = res.LastInsertId() - } + if err != nil { + return err + } + return nil + }) return types.EventStateKeyNID(eventStateKeyNID), err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index c9a461f9..fd4a2e42 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -78,6 +78,7 @@ const bulkSelectEventTypeNIDSQL = ` type eventTypeStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDResultStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt @@ -85,8 +86,10 @@ type eventTypeStatements struct { } func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { - s := &eventTypeStatements{} - s.db = db + s := &eventTypeStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(eventTypesSchema) if err != nil { return nil, err @@ -104,12 +107,15 @@ func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, tx *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - var err error - insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) - resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) - if _, err = insertStmt.ExecContext(ctx, eventType); err == nil { - err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) - } + err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error { + insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) + _, err := insertStmt.ExecContext(ctx, eventType) + if err != nil { + return err + } + return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) + }) return types.EventTypeNID(eventTypeNID), err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index d66db469..378441c3 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -99,6 +99,7 @@ const selectRoomNIDForEventNIDSQL = "" + type eventStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt @@ -115,8 +116,10 @@ type eventStatements struct { } func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { - s := &eventStatements{} - s.db = db + s := &eventStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(eventsSchema) if err != nil { return nil, err @@ -151,19 +154,23 @@ func (s *eventStatements) InsertEvent( depth int64, ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - result, err := insertStmt.ExecContext( - ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, - ) - if err != nil { - return 0, 0, err - } - modified, err := result.RowsAffected() - if modified == 0 && err == nil { - return 0, 0, sql.ErrNoRows - } - eventNID, err := result.LastInsertId() + var eventNID int64 + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + result, err := insertStmt.ExecContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ) + if err != nil { + return err + } + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return sql.ErrNoRows + } + eventNID, err = result.LastInsertId() + return err + }) return types.EventNID(eventNID), 0, err } @@ -279,8 +286,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + return err + }) } func (s *eventStatements) SelectEventSentToOutput( @@ -288,17 +297,15 @@ func (s *eventStatements) SelectEventSentToOutput( ) (sentToOutput bool, err error) { selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) - //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) - if err != nil { - } return } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) - _, err := updateStmt.ExecContext(ctx, int64(eventNID)) - //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) + return err + }) } func (s *eventStatements) SelectEventID( diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 8b6cbe3f..e806eab6 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -63,6 +63,8 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni ` type inviteStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertInviteEventStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt @@ -70,7 +72,10 @@ type inviteStatements struct { } func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { - s := &inviteStatements{} + s := &inviteStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(inviteSchema) if err != nil { return nil, err @@ -90,42 +95,48 @@ func (s *inviteStatements) InsertInviteEvent( targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) - result, err := stmt.ExecContext( - ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, - ) - if err != nil { - return false, err - } - count, err := result.RowsAffected() - if err != nil { - return false, err - } - return count != 0, nil + var count int64 + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + result, err := stmt.ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return err + } + count, err = result.RowsAffected() + if err != nil { + return err + } + return nil + }) + return count != 0, err } func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { - // gather all the event IDs we will retire - stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) - rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) - if err != nil { - return nil, err - } - defer (func() { err = rows.Close() })() - for rows.Next() { - var inviteEventID string - if err = rows.Scan(&inviteEventID); err != nil { - return nil, err + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + // gather all the event IDs we will retire + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return err } - eventIDs = append(eventIDs, inviteEventID) - } - - // now retire the invites - stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) - _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) + defer (func() { err = rows.Close() })() + for rows.Next() { + var inviteEventID string + if err = rows.Scan(&inviteEventID); err != nil { + return err + } + eventIDs = append(eventIDs, inviteEventID) + } + // now retire the invites + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) + return err + }) return } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 6f0d763e..6dd8bd83 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -76,6 +76,8 @@ const updateMembershipSQL = "" + " WHERE room_nid = $4 AND target_nid = $5" type membershipStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt @@ -87,7 +89,10 @@ type membershipStatements struct { } func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { - s := &membershipStatements{} + s := &membershipStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(membershipSchema) if err != nil { return nil, err @@ -110,9 +115,11 @@ func (s *membershipStatements) InsertMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) + return err + }) } func (s *membershipStatements) SelectMembershipForUpdate( @@ -194,9 +201,11 @@ func (s *membershipStatements) UpdateMembership( senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) + _, err := stmt.ExecContext( + ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ) + return err + }) } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 549aecfb..28b5d18f 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -53,12 +53,17 @@ const selectPreviousEventExistsSQL = ` ` type previousEventStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertPreviousEventStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt } func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { - s := &previousEventStatements{} + s := &previousEventStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(previousEventSchema) if err != nil { return nil, err @@ -77,11 +82,13 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err + }) } // Check if the event reference exists diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 9995fff6..96575241 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -43,13 +44,18 @@ const selectPublishedSQL = "" + "SELECT published FROM roomserver_published WHERE room_id = $1" type publishedStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter upsertPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt } func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { - s := &publishedStatements{} + s := &publishedStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(publishedSchema) if err != nil { return nil, err @@ -64,8 +70,10 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { func (s *publishedStatements) UpsertRoomPublished( ctx context.Context, roomID string, published bool, ) (err error) { - _, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + return err + }) } func (s *publishedStatements) SelectPublishedFromRoomID( diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index 1cddb9b4..d2bd2a20 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -52,6 +52,8 @@ const markRedactionValidatedSQL = "" + " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" type redactionStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertRedactionStmt *sql.Stmt selectRedactionInfoByRedactionEventIDStmt *sql.Stmt selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt @@ -59,7 +61,10 @@ type redactionStatements struct { } func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { - s := &redactionStatements{} + s := &redactionStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(redactionsSchema) if err != nil { return nil, err @@ -76,9 +81,11 @@ func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { func (s *redactionStatements) InsertRedaction( ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ) error { - stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) - _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) + return err + }) } func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( @@ -114,7 +121,9 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( func (s *redactionStatements) MarkRedactionValidated( ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ) error { - stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) - _, err := stmt.ExecContext(ctx, redactionEventID, validated) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) + _, err := stmt.ExecContext(ctx, redactionEventID, validated) + return err + }) } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index da5f9161..096b73f9 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -55,6 +56,8 @@ const deleteRoomAliasSQL = ` ` type roomAliasesStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertRoomAliasStmt *sql.Stmt selectRoomIDFromAliasStmt *sql.Stmt selectAliasesFromRoomIDStmt *sql.Stmt @@ -63,7 +66,10 @@ type roomAliasesStatements struct { } func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { - s := &roomAliasesStatements{} + s := &roomAliasesStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(roomAliasesSchema) if err != nil { return nil, err @@ -80,8 +86,10 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { - _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + return err + }) } func (s *roomAliasesStatements) SelectRoomIDFromAlias( @@ -130,6 +138,8 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, ) (err error) { - _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) + return err + }) } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index ab695c5d..9eeadea9 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -64,6 +64,8 @@ const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" type roomStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt @@ -74,7 +76,10 @@ type roomStatements struct { } func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { - s := &roomStatements{} + s := &roomStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(roomsSchema) if err != nil { return nil, err @@ -94,9 +99,12 @@ func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { - var err error - insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) - if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) + _, err := insertStmt.ExecContext(ctx, roomID, roomVersion) + return err + }) + if err == nil { return s.SelectRoomNID(ctx, txn, roomID) } else { return types.RoomNID(0), err @@ -155,15 +163,17 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) - _, err := stmt.ExecContext( - ctx, - eventNIDsAsArray(eventNIDs), - int64(lastEventSentNID), - int64(stateSnapshotNID), - roomNID, - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + roomNID, + ) + return err + }) } func (s *roomStatements) SelectRoomVersionForRoomID( diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index c058c783..3d716b64 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -74,6 +74,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" + type stateBlockStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertStateDataStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt @@ -81,8 +82,10 @@ type stateBlockStatements struct { } func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { - s := &stateBlockStatements{} - s.db = db + s := &stateBlockStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(stateDataSchema) if err != nil { return nil, err @@ -104,24 +107,26 @@ func (s *stateBlockStatements) BulkInsertStateData( return 0, nil } var stateBlockNID types.StateBlockNID - err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) - if err != nil { - return 0, err - } - - for _, entry := range entries { - _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) if err != nil { - return 0, err + return err } - } - return stateBlockNID, nil + for _, entry := range entries { + _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) + if err != nil { + return err + } + } + return nil + }) + return stateBlockNID, err } func (s *stateBlockStatements) BulkSelectStateBlockEntries( diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index d077b617..48f1210b 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -50,13 +50,16 @@ const bulkSelectStateBlockNIDsSQL = "" + type stateSnapshotStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt } func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { - s := &stateSnapshotStatements{} - s.db = db + s := &stateSnapshotStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(stateSnapshotSchema) if err != nil { return nil, err @@ -75,14 +78,19 @@ func (s *stateSnapshotStatements) InsertState( if err != nil { return } - insertStmt := txn.Stmt(s.insertStateStmt) - if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)); err2 == nil { - lastRowID, err3 := res.LastInsertId() - if err3 != nil { - err = err3 + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + insertStmt := txn.Stmt(s.insertStateStmt) + res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + if err != nil { + return err + } + lastRowID, err := res.LastInsertId() + if err != nil { + return err } stateNID = types.StateSnapshotNID(lastRowID) - } + return nil + }) return } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 1e8de1ca..2f6cff95 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -44,12 +44,17 @@ const selectTransactionEventIDSQL = ` ` type transactionStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertTransactionStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt } func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { - s := &transactionStatements{} + s := &transactionStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(transactionsSchema) if err != nil { return nil, err @@ -68,11 +73,13 @@ func (s *transactionStatements) InsertTransaction( userID string, eventID string, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) - _, err = stmt.ExecContext( - ctx, transactionID, sessionID, userID, eventID, - ) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) + _, err := stmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return err + }) } func (s *transactionStatements) SelectTransactionEventID(