Fix memory leaks with SQLite prepared statements (#2253)

This commit is contained in:
Neil Alexander 2022-03-04 15:05:42 +00:00 committed by GitHub
parent 5e694cd362
commit 22a034dcba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 34 additions and 20 deletions

View File

@ -154,6 +154,7 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer selectPrep.Close()
stmt := sqlutil.TxStmt(txn, selectPrep) stmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil { if err != nil {

View File

@ -140,6 +140,7 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer selectPrep.Close()
stmt := sqlutil.TxStmt(txn, selectPrep) stmt := sqlutil.TxStmt(txn, selectPrep)
/////////////// ///////////////

View File

@ -198,11 +198,12 @@ func (s *eventStatements) BulkSelectStateEventByID(
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
@ -266,11 +267,12 @@ func (s *eventStatements) BulkSelectStateEventByNID(
} }
} }
selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC" selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC"
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, fmt.Errorf("s.db.Prepare: %w", err) return nil, fmt.Errorf("s.db.Prepare: %w", err)
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, params...) rows, err := selectStmt.QueryContext(ctx, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@ -307,11 +309,12 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -390,10 +393,11 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectPrep = sqlutil.TxStmt(txn, selectPrep) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
////////////// //////////////
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) rows, err := sqlutil.TxStmt(txn, selectStmt).QueryContext(ctx, iEventNIDs...)
if err != nil { if err != nil {
return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err)
} }
@ -441,6 +445,7 @@ func (s *eventStatements) BulkSelectEventReference(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer selectPrep.Close() // nolint:errcheck
/////////////// ///////////////
selectStmt := sqlutil.TxStmt(txn, selectPrep) selectStmt := sqlutil.TxStmt(txn, selectPrep)
@ -471,11 +476,12 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
@ -526,11 +532,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
} else { } else {
selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
} }
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
/////////////// ///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...) rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil { if err != nil {
@ -560,6 +567,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer sqlPrep.Close()
err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result)
if err != nil { if err != nil {
return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
@ -575,12 +583,13 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep) defer sqlPrep.Close()
sqlStmt := sqlutil.TxStmt(txn, sqlPrep)
iEventNIDs := make([]interface{}, len(eventNIDs)) iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs { for i, v := range eventNIDs {
iEventNIDs[i] = v iEventNIDs[i] = v
} }
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...) rows, err := sqlStmt.QueryContext(ctx, iEventNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -233,12 +233,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlPrep = sqlutil.TxStmt(txn, sqlPrep) defer sqlPrep.Close() // nolint:errcheck
sqlStmt := sqlutil.TxStmt(txn, sqlPrep)
iRoomNIDs := make([]interface{}, len(roomNIDs)) iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs { for i, v := range roomNIDs {
iRoomNIDs[i] = v iRoomNIDs[i] = v
} }
rows, err := sqlPrep.QueryContext(ctx, iRoomNIDs...) rows, err := sqlStmt.QueryContext(ctx, iRoomNIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -108,11 +108,12 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
intfs[i] = int64(stateBlockNIDs[i]) intfs[i] = int64(stateBlockNIDs[i])
} }
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1)
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, intfs...) rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -113,11 +113,12 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
nids[k] = v nids[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
selectStmt = sqlutil.TxStmt(txn, selectStmt) defer selectPrep.Close() // nolint:errcheck
selectStmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := selectStmt.QueryContext(ctx, nids...) rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil { if err != nil {