Always defer *sql.Rows.Close and consult with Err (#844)

* Always defer *sql.Rows.Close and consult with Err

database/sql.Rows.Next() makes sure to call Close only after exhausting
result rows which would NOT happen when returning early from a bad Scan.
Close being idempotent makes it a great candidate to get always deferred
regardless of what happens later on the result set.

This change also makes sure call Err() after exhausting Next() and
propagate non-nil results from it as the documentation advises.

Closes #764

Signed-off-by: Kiril Vladimiroff <kiril@vladimiroff.org>

* Override named result parameters in last returns

Signed-off-by: Kiril Vladimiroff <kiril@vladimiroff.org>

* Do the same over new changes that got merged

Signed-off-by: Kiril Vladimiroff <kiril@vladimiroff.org>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Kiril Vladimiroff 2020-02-11 16:12:21 +02:00 committed by GitHub
parent d45f869cdd
commit d5dbe546e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 81 additions and 49 deletions

View File

@ -90,6 +90,7 @@ func (s *accountDataStatements) selectAccountData(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
global = []gomatrixserverlib.ClientEvent{} global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent) rooms = make(map[string][]gomatrixserverlib.ClientEvent)
@ -114,8 +115,7 @@ func (s *accountDataStatements) selectAccountData(
global = append(global, ac) global = append(global, ac)
} }
} }
return global, rooms, rows.Err()
return
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(

View File

@ -122,11 +122,10 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
for rows.Next() { for rows.Next() {
var m authtypes.Membership var m authtypes.Membership
m.Localpart = localpart m.Localpart = localpart
if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { if err = rows.Scan(&m.RoomID, &m.EventID); err != nil {
return nil, err return
} }
memberships = append(memberships, m) memberships = append(memberships, m)
} }
return memberships, rows.Err()
return
} }

View File

@ -97,6 +97,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
threepids = []authtypes.ThreePID{} threepids = []authtypes.ThreePID{}
for rows.Next() { for rows.Next() {
@ -110,8 +111,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
Medium: medium, Medium: medium,
}) })
} }
return threepids, rows.Err()
return
} }
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) insertThreePID(

View File

@ -226,6 +226,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
if err != nil { if err != nil {
return devices, err return devices, err
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var dev authtypes.Device var dev authtypes.Device
@ -237,5 +238,5 @@ func (s *devicesStatements) selectDevicesByLocalpart(
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, nil return devices, rows.Err()
} }

View File

@ -117,7 +117,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
} }
} }
return results, nil return results, rows.Err()
} }
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(

View File

@ -99,7 +99,7 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
} }
results = append(results, offset) results = append(results, offset)
} }
return results, nil return results, rows.Err()
} }
// UpsertPartitionOffset updates or inserts the partition offset for the given topic. // UpsertPartitionOffset updates or inserts the partition offset for the given topic.

View File

@ -132,5 +132,5 @@ func joinedHostsFromStmt(
}) })
} }
return result, nil return result, rows.Err()
} }

View File

@ -144,6 +144,7 @@ func (s *thumbnailStatements) selectThumbnails(
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck
var thumbnails []*types.ThumbnailMetadata var thumbnails []*types.ThumbnailMetadata
for rows.Next() { for rows.Next() {
@ -167,5 +168,5 @@ func (s *thumbnailStatements) selectThumbnails(
thumbnails = append(thumbnails, &thumbnailMetadata) thumbnails = append(thumbnails, &thumbnailMetadata)
} }
return thumbnails, err return thumbnails, rows.Err()
} }

View File

@ -203,6 +203,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
if err != nil { if err != nil {
return []types.PublicRoom{}, nil return []types.PublicRoom{}, nil
} }
defer rows.Close() // nolint: errcheck
rooms := []types.PublicRoom{} rooms := []types.PublicRoom{}
for rows.Next() { for rows.Next() {
@ -222,7 +223,7 @@ func (s *publicRoomsStatements) selectPublicRooms(
rooms = append(rooms, r) rooms = append(rooms, r)
} }
return rooms, nil return rooms, rows.Err()
} }
func (s *publicRoomsStatements) selectRoomVisibility( func (s *publicRoomsStatements) selectRoomVisibility(

View File

@ -102,5 +102,5 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
} }
result.EventNID = types.EventNID(eventNID) result.EventNID = types.EventNID(eventNID)
} }
return results[:i], nil return results[:i], rows.Err()
} }

View File

@ -125,7 +125,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
} }
result[stateKey] = types.EventStateKeyNID(stateKeyNID) result[stateKey] = types.EventStateKeyNID(stateKeyNID)
} }
return result, nil return result, rows.Err()
} }
func (s *eventStateKeyStatements) bulkSelectEventStateKey( func (s *eventStateKeyStatements) bulkSelectEventStateKey(
@ -150,5 +150,5 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
} }
result[types.EventStateKeyNID(stateKeyNID)] = stateKey result[types.EventStateKeyNID(stateKeyNID)] = stateKey
} }
return result, nil return result, rows.Err()
} }

View File

@ -143,5 +143,5 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
} }
result[eventType] = types.EventTypeNID(eventTypeNID) result[eventType] = types.EventTypeNID(eventTypeNID)
} }
return result, nil return result, rows.Err()
} }

View File

@ -209,6 +209,9 @@ func (s *eventStatements) bulkSelectStateEventByID(
return nil, err return nil, err
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) { if i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query. // We don't know which ones were missing because we don't return the string IDs in the query.
@ -219,7 +222,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
) )
} }
return results, err return results, nil
} }
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
@ -251,12 +254,15 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
) )
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) { if i != len(eventIDs) {
return nil, types.MissingEventError( return nil, types.MissingEventError(
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
) )
} }
return results, err return results, nil
} }
func (s *eventStatements) updateEventState( func (s *eventStatements) updateEventState(
@ -321,6 +327,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
result.EventID = eventID result.EventID = eventID
result.EventSHA256 = eventSHA256 result.EventSHA256 = eventSHA256
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) { if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
} }
@ -343,6 +352,9 @@ func (s *eventStatements) bulkSelectEventReference(
return nil, err return nil, err
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) { if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
} }
@ -366,6 +378,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ
} }
results[types.EventNID(eventNID)] = eventID results[types.EventNID(eventNID)] = eventID
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) { if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
} }
@ -389,7 +404,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str
} }
results[eventID] = types.EventNID(eventNID) results[eventID] = types.EventNID(eventNID)
} }
return results, nil return results, rows.Err()
} }
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {

View File

@ -114,21 +114,23 @@ func (s *inviteStatements) insertInviteEvent(
func (s *inviteStatements) updateInviteRetired( func (s *inviteStatements) updateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) ([]string, error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer (func() { err = rows.Close() })() defer rows.Close() // nolint: errcheck
var eventIDs []string
for rows.Next() { for rows.Next() {
var inviteEventID string var inviteEventID string
if err := rows.Scan(&inviteEventID); err != nil { if err = rows.Scan(&inviteEventID); err != nil {
return nil, err return nil, err
} }
eventIDs = append(eventIDs, inviteEventID) eventIDs = append(eventIDs, inviteEventID)
} }
return return eventIDs, rows.Err()
} }
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
@ -151,5 +153,5 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom(
} }
result = append(result, types.EventStateKeyNID(senderUserNID)) result = append(result, types.EventStateKeyNID(senderUserNID))
} }
return result, nil return result, rows.Err()
} }

View File

@ -151,6 +151,7 @@ func (s *membershipStatements) selectMembershipsFromRoom(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var eNID types.EventNID var eNID types.EventNID
@ -159,8 +160,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
} }
eventNIDs = append(eventNIDs, eNID) eventNIDs = append(eventNIDs, eNID)
} }
return return eventNIDs, rows.Err()
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState,
@ -170,6 +172,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var eNID types.EventNID var eNID types.EventNID
@ -178,7 +181,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
} }
eventNIDs = append(eventNIDs, eNID) eventNIDs = append(eventNIDs, eNID)
} }
return return eventNIDs, rows.Err()
} }
func (s *membershipStatements) updateMembership( func (s *membershipStatements) updateMembership(

View File

@ -90,23 +90,23 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
func (s *roomAliasesStatements) selectAliasesFromRoomID( func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (aliases []string, err error) { ) ([]string, error) {
aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return nil, err
} }
defer rows.Close() // nolint: errcheck
var aliases []string
for rows.Next() { for rows.Next() {
var alias string var alias string
if err = rows.Scan(&alias); err != nil { if err = rows.Scan(&alias); err != nil {
return return nil, err
} }
aliases = append(aliases, alias) aliases = append(aliases, alias)
} }
return aliases, rows.Err()
return
} }
func (s *roomAliasesStatements) selectCreatorIDFromAlias( func (s *roomAliasesStatements) selectCreatorIDFromAlias(

View File

@ -152,7 +152,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
eventNID int64 eventNID int64
entry types.StateEntry entry types.StateEntry
) )
if err := rows.Scan( if err = rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil { ); err != nil {
return nil, err return nil, err
@ -169,10 +169,13 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
} }
current.StateEntries = append(current.StateEntries, entry) current.StateEntries = append(current.StateEntries, entry)
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateBlockNIDs) { if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
} }
return results, nil return results, err
} }
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
@ -237,7 +240,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
if current.StateEntries != nil { if current.StateEntries != nil {
results = append(results, current) results = append(results, current)
} }
return results, nil return results, rows.Err()
} }
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {

View File

@ -104,7 +104,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
result := &results[i] result := &results[i]
var stateBlockNIDs pq.Int64Array var stateBlockNIDs pq.Int64Array
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err return nil, err
} }
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
@ -112,6 +112,9 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
} }
} }
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateNIDs) { if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
} }

View File

@ -118,6 +118,7 @@ func (s *accountDataStatements) selectAccountDataInRange(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var dataType string var dataType string
@ -133,8 +134,7 @@ func (s *accountDataStatements) selectAccountDataInRange(
data[roomID] = []string{dataType} data[roomID] = []string{dataType}
} }
} }
return data, rows.Err()
return
} }
func (s *accountDataStatements) selectMaxAccountDataID( func (s *accountDataStatements) selectMaxAccountDataID(

View File

@ -91,6 +91,7 @@ func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom(
if err != nil { if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
for rows.Next() { for rows.Next() {
var eID string var eID string
@ -101,7 +102,7 @@ func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom(
eventIDs = append(eventIDs, eID) eventIDs = append(eventIDs, eID)
} }
return return eventIDs, rows.Err()
} }
func (s *backwardExtremitiesStatements) isBackwardExtremity( func (s *backwardExtremitiesStatements) isBackwardExtremity(

View File

@ -154,7 +154,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers(
users = append(users, userID) users = append(users, userID)
result[roomID] = users result[roomID] = users
} }
return result, nil return result, rows.Err()
} }
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
@ -179,7 +179,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
} }
result = append(result, roomID) result = append(result, roomID)
} }
return result, nil return result, rows.Err()
} }
// CurrentState returns all the current state events for the given room. // CurrentState returns all the current state events for the given room.
@ -267,7 +267,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
} }
result = append(result, ev) result = append(result, ev)
} }
return result, nil return result, rows.Err()
} }
func (s *currentRoomStateStatements) selectStateEvent( func (s *currentRoomStateStatements) selectStateEvent(

View File

@ -133,7 +133,7 @@ func (s *inviteEventsStatements) selectInviteEventsInRange(
result[roomID] = event result[roomID] = event
} }
return result, nil return result, rows.Err()
} }
func (s *inviteEventsStatements) selectMaxInviteID( func (s *inviteEventsStatements) selectMaxInviteID(

View File

@ -170,6 +170,7 @@ func (s *outputRoomEventsStatements) selectStateInRange(
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer rows.Close() // nolint: errcheck
// Fetch all the state change events for all rooms between the two positions then loop each event and: // Fetch all the state change events for all rooms between the two positions then loop each event and:
// - Keep a cache of the event by ID (99% of state change events are for the event itself) // - Keep a cache of the event by ID (99% of state change events are for the event itself)
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes // - For each room ID, build up an array of event IDs which represents cumulative adds/removes
@ -226,7 +227,7 @@ func (s *outputRoomEventsStatements) selectStateInRange(
} }
} }
return stateNeeded, eventIDToEvent, nil return stateNeeded, eventIDToEvent, rows.Err()
} }
// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
@ -392,5 +393,5 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
ExcludeFromSync: excludeFromSync, ExcludeFromSync: excludeFromSync,
}) })
} }
return result, nil return result, rows.Err()
} }

View File

@ -134,6 +134,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
} else if err != nil { } else if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
// Return the IDs. // Return the IDs.
var eventID string var eventID string
@ -144,7 +145,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
eventIDs = append(eventIDs, eventID) eventIDs = append(eventIDs, eventID)
} }
return return eventIDs, rows.Err()
} }
// selectPositionInTopology returns the position of a given event in the // selectPositionInTopology returns the position of a given event in the
@ -176,6 +177,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition(
} else if err != nil { } else if err != nil {
return return
} }
defer rows.Close() // nolint: errcheck
// Return the IDs. // Return the IDs.
var eventID string var eventID string
for rows.Next() { for rows.Next() {
@ -184,5 +186,5 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition(
} }
eventIDs = append(eventIDs, eventID) eventIDs = append(eventIDs, eventID)
} }
return return eventIDs, rows.Err()
} }