From a6700331cee70a3ca04c834efdc68cc2ef63947d Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 24 Sep 2020 12:10:14 +0200 Subject: [PATCH] Update all usages of tx.Stmt to sqlutil.TxStmt (#1423) * Replace all usages of txn.Stmt with sqlutil.TxStmt Signed-off-by: Sam Day * Fix sign off link in PR template. Signed-off-by: Sam Day Co-authored-by: Neil Alexander --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- internal/sqlutil/sql.go | 8 ++++++++ keyserver/storage/postgres/device_keys_table.go | 7 ++++--- keyserver/storage/postgres/one_time_keys_table.go | 9 +++++---- keyserver/storage/sqlite3/device_keys_table.go | 6 +++--- keyserver/storage/sqlite3/one_time_keys_table.go | 9 +++++---- roomserver/storage/postgres/state_snapshot_table.go | 3 ++- roomserver/storage/sqlite3/state_block_table.go | 4 ++-- roomserver/storage/sqlite3/state_snapshot_table.go | 2 +- syncapi/storage/postgres/backwards_extremities_table.go | 4 ++-- syncapi/storage/postgres/send_to_device_table.go | 4 ++-- syncapi/storage/sqlite3/account_data_table.go | 5 +++-- syncapi/storage/sqlite3/backwards_extremities_table.go | 4 ++-- userapi/storage/accounts/postgres/account_data_table.go | 3 ++- userapi/storage/accounts/postgres/accounts_table.go | 5 +++-- userapi/storage/accounts/postgres/profile_table.go | 3 ++- userapi/storage/accounts/sqlite3/account_data_table.go | 4 +++- userapi/storage/accounts/sqlite3/accounts_table.go | 7 ++++--- userapi/storage/accounts/sqlite3/profile_table.go | 2 +- 19 files changed, 55 insertions(+), 36 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index f4724c31..8d22f7f7 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,4 +3,4 @@ * [ ] I have added any new tests that need to pass to `testfile` as specified in [docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md) -* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/CONTRIBUTING.md#sign-off) +* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off) diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 90562ded..a3885d66 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -88,6 +88,14 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt { return statement } +// TxStmtContext behaves similarly to TxStmt, with support for also passing context. +func TxStmtContext(context context.Context, transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt { + if transaction != nil { + statement = transaction.StmtContext(context, statement) + } + return statement +} + // Hack of the century func QueryVariadic(count int) string { return QueryVariadicOffset(count, 0) diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index e5bec8f6..95064fc8 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { // nullable if there are no results var nullStream sql.NullInt32 - err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } @@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { for _, key := range keys { now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { @@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx } func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) return err } diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index a299861d..6e32838b 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql. } for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) - _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ) if err != nil { return nil, err } } - rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) if err != nil { return nil, err } @@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index e7ff9976..9112fc6e 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { } func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) return err } @@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { // nullable if there are no results var nullStream sql.NullInt32 - err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } @@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { for _, key := range keys { now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 1b6a74d6..185b8861 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( } for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) - _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ) if err != nil { return nil, err } } - rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) if err != nil { return nil, err } @@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 0f8f1c51..63175955 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -86,7 +87,7 @@ func (s *stateSnapshotStatements) InsertState( for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) return } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 8033903f..2c544f2b 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -105,12 +105,12 @@ func (s *stateBlockStatements) BulkInsertStateData( return 0, nil } var stateBlockNID types.StateBlockNID - err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) if err != nil { return 0, err } for _, entry := range entries { - _, err = txn.Stmt(s.insertStateDataStmt).ExecContext( + _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( ctx, int64(stateBlockNID), int64(entry.EventTypeNID), diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 392c2a67..bf49f62c 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -76,7 +76,7 @@ func (s *stateSnapshotStatements) InsertState( if err != nil { return } - insertStmt := txn.Stmt(s.insertStateStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) if err != nil { return 0, err diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 13056588..d5cf563a 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -81,7 +81,7 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) return } @@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 07af9ad6..be9c347b 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -160,13 +160,13 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ) (err error) { - _, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) + _, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) return } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ) (err error) { - _, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) + _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) return } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 72c46e48..4bcc06ed 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -85,7 +86,7 @@ func (s *accountDataStatements) InsertAccountData( if err != nil { return } - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) + _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) return } @@ -147,7 +148,7 @@ func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) + err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 9a81e8e7..662cb025 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -84,7 +84,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) return err } @@ -113,7 +113,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 90c79e87..09eb2611 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const accountDataSchema = ` @@ -75,7 +76,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) insertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { - stmt := txn.Stmt(s.insertAccountDataStmt) + stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) return } diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 8c8d32cf..7500e1e8 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -99,7 +100,7 @@ func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - stmt := txn.Stmt(s.insertAccountStmt) + stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if appserviceID == "" { @@ -162,7 +163,7 @@ func (s *accountsStatements) selectNewNumericLocalpart( ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { - stmt = txn.Stmt(stmt) + stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) return diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index 14b12c35..45d802f1 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const profilesSchema = ` @@ -84,7 +85,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) insertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { - _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index f9430c24..870a3706 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" "encoding/json" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const accountDataSchema = ` @@ -75,7 +77,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) insertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) error { - _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return err } diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index fbbdc337..2d935fb6 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -104,9 +105,9 @@ func (s *accountsStatements) insertAccount( var err error if appserviceID == "" { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) } else { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) } if err != nil { return nil, err @@ -163,7 +164,7 @@ func (s *accountsStatements) selectNewNumericLocalpart( ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { - stmt = txn.Stmt(stmt) + stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) return diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 4eeaf037..a67e892f 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -87,7 +87,7 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) insertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) error { - _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return err }