Use TransactionWriter in roomserver SQLite (#1208)

This commit is contained in:
Neil Alexander 2020-07-21 10:48:49 +01:00 committed by GitHub
parent 489f34fed7
commit d76eb1b994
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 259 additions and 148 deletions

View File

@ -49,13 +49,16 @@ const bulkSelectEventJSONSQL = `
type eventJSONStatements struct { type eventJSONStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertEventJSONStmt *sql.Stmt insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt
} }
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
s := &eventJSONStatements{} s := &eventJSONStatements{
s.db = db db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventJSONSchema) _, err := db.Exec(eventJSONSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -69,8 +72,10 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
func (s *eventJSONStatements) InsertEventJSON( func (s *eventJSONStatements) InsertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
return err return err
})
} }
func (s *eventJSONStatements) BulkSelectEventJSON( func (s *eventJSONStatements) BulkSelectEventJSON(

View File

@ -64,6 +64,7 @@ const bulkSelectEventStateKeyNIDSQL = `
type eventStateKeyStatements struct { type eventStateKeyStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertEventStateKeyNIDStmt *sql.Stmt insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt
@ -71,8 +72,10 @@ type eventStateKeyStatements struct {
} }
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{} s := &eventStateKeyStatements{
s.db = db db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventStateKeysSchema) _, err := db.Exec(eventStateKeysSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,12 +92,18 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string, ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) { ) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
var err error err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
var res sql.Result
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { res, err := insertStmt.ExecContext(ctx, eventStateKey)
eventStateKeyNID, err = res.LastInsertId() if err != nil {
return err
} }
eventStateKeyNID, err = res.LastInsertId()
if err != nil {
return err
}
return nil
})
return types.EventStateKeyNID(eventStateKeyNID), err return types.EventStateKeyNID(eventStateKeyNID), err
} }

View File

@ -78,6 +78,7 @@ const bulkSelectEventTypeNIDSQL = `
type eventTypeStatements struct { type eventTypeStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt
@ -85,8 +86,10 @@ type eventTypeStatements struct {
} }
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
s := &eventTypeStatements{} s := &eventTypeStatements{
s.db = db db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventTypesSchema) _, err := db.Exec(eventTypesSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,12 +107,15 @@ func (s *eventTypeStatements) InsertEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string, ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) { ) (types.EventTypeNID, error) {
var eventTypeNID int64 var eventTypeNID int64
var err error err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
if _, err = insertStmt.ExecContext(ctx, eventType); err == nil { _, err := insertStmt.ExecContext(ctx, eventType)
err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) if err != nil {
return err
} }
return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
})
return types.EventTypeNID(eventTypeNID), err return types.EventTypeNID(eventTypeNID), err
} }

View File

@ -99,6 +99,7 @@ const selectRoomNIDForEventNIDSQL = "" +
type eventStatements struct { type eventStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt
@ -115,8 +116,10 @@ type eventStatements struct {
} }
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
s := &eventStatements{} s := &eventStatements{
s.db = db db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventsSchema) _, err := db.Exec(eventsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -151,19 +154,23 @@ func (s *eventStatements) InsertEvent(
depth int64, depth int64,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
// attempt to insert: the last_row_id is the event NID // attempt to insert: the last_row_id is the event NID
var eventNID int64
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
result, err := insertStmt.ExecContext( result, err := insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
) )
if err != nil { if err != nil {
return 0, 0, err return err
} }
modified, err := result.RowsAffected() modified, err := result.RowsAffected()
if modified == 0 && err == nil { if modified == 0 && err == nil {
return 0, 0, sql.ErrNoRows return sql.ErrNoRows
} }
eventNID, err := result.LastInsertId() eventNID, err = result.LastInsertId()
return err
})
return types.EventNID(eventNID), 0, err return types.EventNID(eventNID), 0, err
} }
@ -279,8 +286,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
func (s *eventStatements) UpdateEventState( func (s *eventStatements) UpdateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error { ) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
return err return err
})
} }
func (s *eventStatements) SelectEventSentToOutput( func (s *eventStatements) SelectEventSentToOutput(
@ -288,17 +297,15 @@ func (s *eventStatements) SelectEventSentToOutput(
) (sentToOutput bool, err error) { ) (sentToOutput bool, err error) {
selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt) selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
//err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
if err != nil {
}
return return
} }
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
_, err := updateStmt.ExecContext(ctx, int64(eventNID)) _, err := updateStmt.ExecContext(ctx, int64(eventNID))
//_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID))
return err return err
})
} }
func (s *eventStatements) SelectEventID( func (s *eventStatements) SelectEventID(

View File

@ -63,6 +63,8 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
` `
type inviteStatements struct { type inviteStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt
@ -70,7 +72,10 @@ type inviteStatements struct {
} }
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
s := &inviteStatements{} s := &inviteStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(inviteSchema) _, err := db.Exec(inviteSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -90,42 +95,48 @@ func (s *inviteStatements) InsertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
var count int64
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
result, err := stmt.ExecContext( result, err := stmt.ExecContext(
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
) )
if err != nil { if err != nil {
return false, err return err
} }
count, err := result.RowsAffected() count, err = result.RowsAffected()
if err != nil { if err != nil {
return false, err return err
} }
return count != 0, nil return nil
})
return count != 0, err
} }
func (s *inviteStatements) UpdateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
// gather all the event IDs we will retire // gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return err
} }
defer (func() { err = rows.Close() })() defer (func() { err = rows.Close() })()
for rows.Next() { for rows.Next() {
var inviteEventID string var inviteEventID string
if err = rows.Scan(&inviteEventID); err != nil { if err = rows.Scan(&inviteEventID); err != nil {
return nil, err return err
} }
eventIDs = append(eventIDs, inviteEventID) eventIDs = append(eventIDs, inviteEventID)
} }
// now retire the invites // now retire the invites
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID) _, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return err
})
return return
} }

View File

@ -76,6 +76,8 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5" " WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct { type membershipStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt
@ -87,7 +89,10 @@ type membershipStatements struct {
} }
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{} s := &membershipStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(membershipSchema) _, err := db.Exec(membershipSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -110,9 +115,11 @@ func (s *membershipStatements) InsertMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err return err
})
} }
func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipForUpdate(
@ -194,9 +201,11 @@ func (s *membershipStatements) UpdateMembership(
senderUserNID types.EventStateKeyNID, membership tables.MembershipState, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
) )
return err return err
})
} }

View File

@ -53,12 +53,17 @@ const selectPreviousEventExistsSQL = `
` `
type previousEventStatements struct { type previousEventStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertPreviousEventStmt *sql.Stmt insertPreviousEventStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt
} }
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
s := &previousEventStatements{} s := &previousEventStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(previousEventSchema) _, err := db.Exec(previousEventSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,11 +82,13 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte, previousEventReferenceSHA256 []byte,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
) )
return err return err
})
} }
// Check if the event reference exists // Check if the event reference exists

View File

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
@ -43,13 +44,18 @@ const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1" "SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct { type publishedStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
upsertPublishedStmt *sql.Stmt upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt
} }
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
s := &publishedStatements{} s := &publishedStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(publishedSchema) _, err := db.Exec(publishedSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -64,8 +70,10 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
func (s *publishedStatements) UpsertRoomPublished( func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, roomID string, published bool, ctx context.Context, roomID string, published bool,
) (err error) { ) (err error) {
_, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
return err
})
} }
func (s *publishedStatements) SelectPublishedFromRoomID( func (s *publishedStatements) SelectPublishedFromRoomID(

View File

@ -52,6 +52,8 @@ const markRedactionValidatedSQL = "" +
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1" " UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct { type redactionStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRedactionStmt *sql.Stmt insertRedactionStmt *sql.Stmt
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
@ -59,7 +61,10 @@ type redactionStatements struct {
} }
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
s := &redactionStatements{} s := &redactionStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(redactionsSchema) _, err := db.Exec(redactionsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -76,9 +81,11 @@ func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
func (s *redactionStatements) InsertRedaction( func (s *redactionStatements) InsertRedaction(
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
return err return err
})
} }
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
@ -114,7 +121,9 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
func (s *redactionStatements) MarkRedactionValidated( func (s *redactionStatements) MarkRedactionValidated(
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
_, err := stmt.ExecContext(ctx, redactionEventID, validated) _, err := stmt.ExecContext(ctx, redactionEventID, validated)
return err return err
})
} }

View File

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
) )
@ -55,6 +56,8 @@ const deleteRoomAliasSQL = `
` `
type roomAliasesStatements struct { type roomAliasesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRoomAliasStmt *sql.Stmt insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt selectAliasesFromRoomIDStmt *sql.Stmt
@ -63,7 +66,10 @@ type roomAliasesStatements struct {
} }
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{} s := &roomAliasesStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(roomAliasesSchema) _, err := db.Exec(roomAliasesSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -80,8 +86,10 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func (s *roomAliasesStatements) InsertRoomAlias( func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string, ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) { ) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return err
})
} }
func (s *roomAliasesStatements) SelectRoomIDFromAlias( func (s *roomAliasesStatements) SelectRoomIDFromAlias(
@ -130,6 +138,8 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
func (s *roomAliasesStatements) DeleteRoomAlias( func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string, ctx context.Context, alias string,
) (err error) { ) (err error) {
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
return _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
return err
})
} }

View File

@ -64,6 +64,8 @@ const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
type roomStatements struct { type roomStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRoomNIDStmt *sql.Stmt insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt
@ -74,7 +76,10 @@ type roomStatements struct {
} }
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{} s := &roomStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(roomsSchema) _, err := db.Exec(roomsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -94,9 +99,12 @@ func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) { ) (types.RoomNID, error) {
var err error err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { _, err := insertStmt.ExecContext(ctx, roomID, roomVersion)
return err
})
if err == nil {
return s.SelectRoomNID(ctx, txn, roomID) return s.SelectRoomNID(ctx, txn, roomID)
} else { } else {
return types.RoomNID(0), err return types.RoomNID(0), err
@ -155,6 +163,7 @@ func (s *roomStatements) UpdateLatestEventNIDs(
lastEventSentNID types.EventNID, lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID, stateSnapshotNID types.StateSnapshotNID,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
@ -164,6 +173,7 @@ func (s *roomStatements) UpdateLatestEventNIDs(
roomNID, roomNID,
) )
return err return err
})
} }
func (s *roomStatements) SelectRoomVersionForRoomID( func (s *roomStatements) SelectRoomVersionForRoomID(

View File

@ -74,6 +74,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
type stateBlockStatements struct { type stateBlockStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertStateDataStmt *sql.Stmt insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt
@ -81,8 +82,10 @@ type stateBlockStatements struct {
} }
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{} s := &stateBlockStatements{
s.db = db db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(stateDataSchema) _, err := db.Exec(stateDataSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,11 +107,11 @@ func (s *stateBlockStatements) BulkInsertStateData(
return 0, nil return 0, nil
} }
var stateBlockNID types.StateBlockNID var stateBlockNID types.StateBlockNID
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil { if err != nil {
return 0, err return err
} }
for _, entry := range entries { for _, entry := range entries {
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext( _, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
ctx, ctx,
@ -118,10 +121,12 @@ func (s *stateBlockStatements) BulkInsertStateData(
int64(entry.EventNID), int64(entry.EventNID),
) )
if err != nil { if err != nil {
return 0, err return err
} }
} }
return stateBlockNID, nil return nil
})
return stateBlockNID, err
} }
func (s *stateBlockStatements) BulkSelectStateBlockEntries( func (s *stateBlockStatements) BulkSelectStateBlockEntries(

View File

@ -50,13 +50,16 @@ const bulkSelectStateBlockNIDsSQL = "" +
type stateSnapshotStatements struct { type stateSnapshotStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertStateStmt *sql.Stmt insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
} }
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{} s := &stateSnapshotStatements{
s.db = db db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(stateSnapshotSchema) _, err := db.Exec(stateSnapshotSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,14 +78,19 @@ func (s *stateSnapshotStatements) InsertState(
if err != nil { if err != nil {
return return
} }
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := txn.Stmt(s.insertStateStmt) insertStmt := txn.Stmt(s.insertStateStmt)
if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)); err2 == nil { res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
lastRowID, err3 := res.LastInsertId() if err != nil {
if err3 != nil { return err
err = err3 }
lastRowID, err := res.LastInsertId()
if err != nil {
return err
} }
stateNID = types.StateSnapshotNID(lastRowID) stateNID = types.StateSnapshotNID(lastRowID)
} return nil
})
return return
} }

View File

@ -44,12 +44,17 @@ const selectTransactionEventIDSQL = `
` `
type transactionStatements struct { type transactionStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertTransactionStmt *sql.Stmt insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt
} }
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
s := &transactionStatements{} s := &transactionStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(transactionsSchema) _, err := db.Exec(transactionsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -68,11 +73,13 @@ func (s *transactionStatements) InsertTransaction(
userID string, userID string,
eventID string, eventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
_, err = stmt.ExecContext( _, err := stmt.ExecContext(
ctx, transactionID, sessionID, userID, eventID, ctx, transactionID, sessionID, userID, eventID,
) )
return return err
})
} }
func (s *transactionStatements) SelectTransactionEventID( func (s *transactionStatements) SelectTransactionEventID(