diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 7ba51fb5..7fe6b65b 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -35,7 +35,9 @@ type Database struct { queuePDUsStatements queueJSONStatements sqlutil.PartitionOffsetStatements - db *sql.DB + db *sql.DB + queuePDUsWriter *sqlutil.TransactionWriter + queueJSONWriter *sqlutil.TransactionWriter } // NewDatabase opens a new database @@ -74,6 +76,9 @@ func (d *Database) prepare() error { return err } + d.queuePDUsWriter = sqlutil.NewTransactionWriter() + d.queueJSONWriter = sqlutil.NewTransactionWriter() + return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") } @@ -145,12 +150,16 @@ func (d *Database) GetJoinedHosts( // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (int64, error) { - nid, err := d.insertQueueJSON(ctx, nil, js) - if err != nil { - return 0, fmt.Errorf("d.insertQueueJSON: %w", err) - } - return nid, nil +) (nid int64, err error) { + err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error { + n, e := d.insertQueueJSON(ctx, nil, js) + if e != nil { + return fmt.Errorf("d.insertQueueJSON: %w", e) + } + nid = n + return nil + }) + return } // AssociatePDUWithDestination creates an association that the @@ -162,7 +171,7 @@ func (d *Database) AssociatePDUWithDestination( serverName gomatrixserverlib.ServerName, nids []int64, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error { for _, nid := range nids { if err := d.insertQueuePDU( ctx, // context @@ -230,18 +239,18 @@ func (d *Database) CleanTransactionPDUs( serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) + var err error + var nids []int64 + var deleteNIDs []int64 + if err = d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error { + nids, err = d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) if err != nil { return fmt.Errorf("d.selectQueuePDUs: %w", err) } - if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { return fmt.Errorf("d.deleteQueueTransaction: %w", err) } - var count int64 - var deleteNIDs []int64 for _, nid := range nids { count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid) if err != nil { @@ -251,15 +260,19 @@ func (d *Database) CleanTransactionPDUs( deleteNIDs = append(deleteNIDs, nid) } } - + return nil + }); err != nil { + return err + } + err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error { if len(deleteNIDs) > 0 { if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { return fmt.Errorf("d.deleteQueueJSON: %w", err) } } - return nil }) + return err } // GetPendingPDUCount returns the number of PDUs waiting to be