diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go index 27ca5b61..0d73cb31 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go @@ -15,6 +15,7 @@ package accounts import ( + "context" "database/sql" "github.com/matrix-org/gomatrixserverlib" @@ -70,17 +71,22 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { return } -func (s *accountDataStatements) insertAccountData(localpart string, roomID string, dataType string, content string) (err error) { - _, err = s.insertAccountDataStmt.Exec(localpart, roomID, dataType, content) +func (s *accountDataStatements) insertAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) (err error) { + stmt := s.insertAccountDataStmt + _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) return } -func (s *accountDataStatements) selectAccountData(localpart string) ( +func (s *accountDataStatements) selectAccountData( + ctx context.Context, localpart string, +) ( global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error, ) { - rows, err := s.selectAccountDataStmt.Query(localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { return } @@ -93,7 +99,7 @@ func (s *accountDataStatements) selectAccountData(localpart string) ( var dataType string var content []byte - if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows { + if err = rows.Scan(&roomID, &dataType, &content); err != nil { return } @@ -113,11 +119,12 @@ func (s *accountDataStatements) selectAccountData(localpart string) ( } func (s *accountDataStatements) selectAccountDataByType( - localpart string, roomID string, dataType string, + ctx context.Context, localpart, roomID, dataType string, ) (data []gomatrixserverlib.ClientEvent, err error) { data = []gomatrixserverlib.ClientEvent{} - rows, err := s.selectAccountDataByTypeStmt.Query(localpart, roomID, dataType) + stmt := s.selectAccountDataByTypeStmt + rows, err := stmt.QueryContext(ctx, localpart, roomID, dataType) if err != nil { return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go index d1177f3e..68a80917 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go @@ -15,6 +15,7 @@ package accounts import ( + "context" "database/sql" "fmt" "time" @@ -76,26 +77,34 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount(localpart, hash string) (acc *authtypes.Account, err error) { +func (s *accountsStatements) insertAccount( + ctx context.Context, localpart, hash string, +) (*authtypes.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - if _, err = s.insertAccountStmt.Exec(localpart, createdTimeMS, hash); err == nil { - acc = &authtypes.Account{ - Localpart: localpart, - UserID: makeUserID(localpart, s.serverName), - ServerName: s.serverName, - } + stmt := s.insertAccountStmt + if _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash); err != nil { + return nil, err } + return &authtypes.Account{ + Localpart: localpart, + UserID: makeUserID(localpart, s.serverName), + ServerName: s.serverName, + }, nil +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectPasswordHash(localpart string) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRow(localpart).Scan(&hash) - return -} - -func (s *accountsStatements) selectAccountByLocalpart(localpart string) (*authtypes.Account, error) { +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Account, error) { var acc authtypes.Account - err := s.selectAccountByLocalpartStmt.QueryRow(localpart).Scan(&acc.Localpart) + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart) if err != nil { acc.UserID = makeUserID(localpart, s.serverName) acc.ServerName = s.serverName diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go index 590c5ed8..75b65b53 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go @@ -15,6 +15,7 @@ package accounts import ( + "context" "database/sql" "github.com/lib/pq" @@ -80,18 +81,27 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { return } -func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) { - _, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID) +func (s *membershipStatements) insertMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) (err error) { + stmt := txn.Stmt(s.insertMembershipStmt) + _, err = stmt.ExecContext(ctx, localpart, roomID, eventID) return } -func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) { - _, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs)) +func (s *membershipStatements) deleteMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (err error) { + stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt) + _, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs)) return } -func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { - rows, err := s.selectMembershipsByLocalpartStmt.Query(localpart) +func (s *membershipStatements) selectMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + stmt := s.selectMembershipsByLocalpartStmt + rows, err := stmt.QueryContext(ctx, localpart) if err != nil { return } @@ -111,7 +121,11 @@ func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (m return } -func (s *membershipStatements) updateMembershipByEventID(oldEventID string, newEventID string) (err error) { - _, err = s.updateMembershipByEventIDStmt.Exec(oldEventID, newEventID) +func (s *membershipStatements) updateMembershipByEventID( + ctx context.Context, oldEventID string, newEventID string, +) (err error) { + _, err = s.updateMembershipByEventIDStmt.ExecContext( + ctx, oldEventID, newEventID, + ) return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/profile_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/profile_table.go index 2976c276..157bb99b 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/profile_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/profile_table.go @@ -15,6 +15,7 @@ package accounts import ( + "context" "database/sql" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -71,23 +72,36 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { return } -func (s *profilesStatements) insertProfile(localpart string) (err error) { - _, err = s.insertProfileStmt.Exec(localpart, "", "") +func (s *profilesStatements) insertProfile( + ctx context.Context, localpart string, +) (err error) { + _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") return } -func (s *profilesStatements) selectProfileByLocalpart(localpart string) (*authtypes.Profile, error) { +func (s *profilesStatements) selectProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRow(localpart).Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL) - return &profile, err + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + ) + if err != nil { + return nil, err + } + return &profile, nil } -func (s *profilesStatements) setAvatarURL(localpart string, avatarURL string) (err error) { - _, err = s.setAvatarURLStmt.Exec(avatarURL, localpart) +func (s *profilesStatements) setAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) (err error) { + _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) return } -func (s *profilesStatements) setDisplayName(localpart string, displayName string) (err error) { - _, err = s.setDisplayNameStmt.Exec(displayName, localpart) +func (s *profilesStatements) setDisplayName( + ctx context.Context, localpart string, displayName string, +) (err error) { + _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index c025853e..26607cdf 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -15,6 +15,7 @@ package accounts import ( + "context" "database/sql" "errors" @@ -74,46 +75,56 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword(localpart, plaintextPassword string) (*authtypes.Account, error) { - hash, err := d.accounts.selectPasswordHash(localpart) +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*authtypes.Account, error) { + hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.accounts.selectAccountByLocalpart(localpart) + return d.accounts.selectAccountByLocalpart(ctx, localpart) } // GetProfileByLocalpart returns the profile associated with the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart(localpart string) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(localpart) +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.profiles.selectProfileByLocalpart(ctx, localpart) } // SetAvatarURL updates the avatar URL of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL(localpart string, avatarURL string) error { - return d.profiles.setAvatarURL(localpart, avatarURL) +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.profiles.setAvatarURL(ctx, localpart, avatarURL) } // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName(localpart string, displayName string) error { - return d.profiles.setDisplayName(localpart, displayName) +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.profiles.setDisplayName(ctx, localpart, displayName) } // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. -func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtypes.Account, error) { +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword string, +) (*authtypes.Account, error) { hash, err := hashPassword(plaintextPassword) if err != nil { return nil, err } - if err := d.profiles.insertProfile(localpart); err != nil { + if err := d.profiles.insertProfile(ctx, localpart); err != nil { return nil, err } - return d.accounts.insertAccount(localpart, hash) + return d.accounts.insertAccount(ctx, localpart, hash) } // PartitionOffsets implements common.PartitionStorer @@ -131,15 +142,19 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 // is still in the room. // If a membership already exists between the user and the room, or of the // insert fails, returns the SQL error -func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error { - return d.memberships.insertMembership(localpart, roomID, eventID, txn) +func (d *Database) saveMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) error { + return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID) } // removeMembershipsByEventIDs removes the memberships of which the `join` membership // event ID is included in a given array of events IDs // If the removal fails, or if there is no membership to remove, returns an error -func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error { - return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn) +func (d *Database) removeMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs) } // UpdateMemberships adds the "join" membership events included in a given state @@ -147,14 +162,16 @@ func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) e // IDs. All of the process is run in a transaction, which commits only once/if every // insertion and deletion has been successfully processed. // Returns a SQL error if there was an issue with any part of the process -func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error { +func (d *Database) UpdateMemberships( + ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, +) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil { + if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { return err } for _, event := range eventsToAdd { - if err := d.newMembership(event, txn); err != nil { + if err := d.newMembership(ctx, txn, event); err != nil { return err } } @@ -167,8 +184,10 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT // the rooms a user matching a given localpart is a member of // If no membership match the given localpart, returns an empty array // If there was an issue during the retrieval, returns the SQL error -func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { - return d.memberships.selectMembershipsByLocalpart(localpart) +func (d *Database) GetMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + return d.memberships.selectMembershipsByLocalpart(ctx, localpart) } // newMembership will save a new membership in the database, with a flag on whether @@ -178,7 +197,9 @@ func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []au // values, does nothing. // If the event isn't a "join" membership event, does nothing // If an error occurred, returns it -func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error { +func (d *Database) newMembership( + ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event, +) error { if ev.Type() == "m.room.member" && ev.StateKey() != nil { localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) if err != nil { @@ -199,7 +220,7 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error // Only "join" membership events can be considered as new memberships if membership == "join" { - if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil { + if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { return err } } @@ -212,27 +233,33 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error // If an account data already exists for a given set (user, room, data type), it will // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData(localpart string, roomID string, dataType string, content string) error { - return d.accountDatas.insertAccountData(localpart, roomID, dataType, content) +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) error { + return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) } // GetAccountData returns account data related to a given localpart // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(localpart string) ( +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error, ) { - return d.accountDatas.selectAccountData(localpart) + return d.accountDatas.selectAccountData(ctx, localpart) } // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns an empty array // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType(localpart string, roomID string, dataType string) (data []gomatrixserverlib.ClientEvent, err error) { - return d.accountDatas.selectAccountDataByType(localpart, roomID, dataType) +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data []gomatrixserverlib.ClientEvent, err error) { + return d.accountDatas.selectAccountDataByType( + ctx, localpart, roomID, dataType, + ) } func hashPassword(plaintext string) (hash string, err error) { @@ -248,9 +275,13 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") // and a local Matrix user (identified by the user's ID's local part). // If the third-party identifier is already part of an association, returns Err3PIDInUse. // Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, medium string) (err error) { +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { return common.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID(txn, threepid, medium) + user, err := d.threepids.selectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) if err != nil { return err } @@ -259,7 +290,7 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me return Err3PIDInUse } - return d.threepids.insertThreePID(txn, threepid, medium, localpart) + return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) }) } @@ -267,8 +298,10 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me // identifier. // If no association exists involving this third-party identifier, returns nothing. // If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (err error) { - return d.threepids.deleteThreePID(threepid, medium) +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.threepids.deleteThreePID(ctx, threepid, medium) } // GetLocalpartForThreePID looks up the localpart associated with a given third-party @@ -276,14 +309,18 @@ func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (er // If no association involves the given third-party idenfitier, returns an empty // string. // Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID(threepid string, medium string) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(nil, threepid, medium) +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) } // GetThreePIDsForLocalpart looks up the third-party identifiers associated with // a given local user. // If no association is known for this user, returns an empty slice. // Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(localpart) +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go index 55bcee6b..5900260a 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/threepid_table.go @@ -15,8 +15,11 @@ package accounts import ( + "context" "database/sql" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -76,22 +79,21 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { return } -func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string, medium string) (localpart string, err error) { - var stmt *sql.Stmt - if txn != nil { - stmt = txn.Stmt(s.selectLocalpartForThreePIDStmt) - } else { - stmt = s.selectLocalpartForThreePIDStmt - } - err = stmt.QueryRow(threepid, medium).Scan(&localpart) +func (s *threepidStatements) selectLocalpartForThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string, +) (localpart string, err error) { + stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) if err == sql.ErrNoRows { return "", nil } return } -func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.Query(localpart) +func (s *threepidStatements) selectThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) if err != nil { return } @@ -103,18 +105,25 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (thre if err = rows.Scan(&threepid, &medium); err != nil { return } - threepids = append(threepids, authtypes.ThreePID{threepid, medium}) + threepids = append(threepids, authtypes.ThreePID{ + Address: threepid, + Medium: medium, + }) } return } -func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, medium string, localpart string) (err error) { - _, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, medium, localpart) +func (s *threepidStatements) insertThreePID( + ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, +) (err error) { + stmt := common.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) return } -func (s *threepidStatements) deleteThreePID(threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.Exec(threepid, medium) +func (s *threepidStatements) deleteThreePID( + ctx context.Context, threepid string, medium string) (err error) { + _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go index 591354b6..d20cf6d2 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/consumers/roomserver.go @@ -96,7 +96,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { return err } - if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil { + if err := s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs); err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go b/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go index d3bb932d..bd64e2b0 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go @@ -59,7 +59,9 @@ func SaveAccountData( return httputil.LogThenError(req, err) } - if err := accountDB.SaveAccountData(localpart, roomID, dataType, string(body)); err != nil { + if err := accountDB.SaveAccountData( + req.Context(), localpart, roomID, dataType, string(body), + ); err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/login.go b/src/github.com/matrix-org/dendrite/clientapi/readers/login.go index 1f6b5071..027560cc 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/login.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/login.go @@ -79,7 +79,7 @@ func Login( util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request") - acc, err := accountDB.GetAccountByPassword(r.User, r.Password) + acc, err := accountDB.GetAccountByPassword(req.Context(), r.User, r.Password) if err != nil { // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go b/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go index 2047dc71..532f01e4 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go @@ -60,7 +60,7 @@ func GetProfile( return httputil.LogThenError(req, err) } - profile, err := accountDB.GetProfileByLocalpart(localpart) + profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } @@ -83,7 +83,7 @@ func GetAvatarURL( return httputil.LogThenError(req, err) } - profile, err := accountDB.GetProfileByLocalpart(localpart) + profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } @@ -127,16 +127,16 @@ func SetAvatarURL( return httputil.LogThenError(req, err) } - oldProfile, err := accountDB.GetProfileByLocalpart(localpart) + oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } - if err = accountDB.SetAvatarURL(localpart, r.AvatarURL); err != nil { + if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil { return httputil.LogThenError(req, err) } - memberships, err := accountDB.GetMembershipsByLocalpart(localpart) + memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } @@ -175,7 +175,7 @@ func GetDisplayName( return httputil.LogThenError(req, err) } - profile, err := accountDB.GetProfileByLocalpart(localpart) + profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } @@ -219,16 +219,16 @@ func SetDisplayName( return httputil.LogThenError(req, err) } - oldProfile, err := accountDB.GetProfileByLocalpart(localpart) + oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } - if err = accountDB.SetDisplayName(localpart, r.DisplayName); err != nil { + if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil { return httputil.LogThenError(req, err) } - memberships, err := accountDB.GetMembershipsByLocalpart(localpart) + memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go b/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go index b0d79287..b038e352 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/threepid.go @@ -49,7 +49,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf var err error // Check if the 3PID is already in use locally - localpart, err := accountDB.GetLocalpartForThreePID(body.Email, "email") + localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email") if err != nil { return httputil.LogThenError(req, err) } @@ -64,7 +64,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf } } - resp.SID, err = threepid.CreateSession(body, cfg) + resp.SID, err = threepid.CreateSession(req.Context(), body, cfg) if err == threepid.ErrNotTrusted { return util.JSONResponse{ Code: 400, @@ -91,7 +91,7 @@ func CheckAndSave3PIDAssociation( } // Check if the association has been validated - verified, address, medium, err := threepid.CheckAssociation(body.Creds, cfg) + verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg) if err == threepid.ErrNotTrusted { return util.JSONResponse{ Code: 400, @@ -130,7 +130,7 @@ func CheckAndSave3PIDAssociation( return httputil.LogThenError(req, err) } - if err = accountDB.SaveThreePIDAssociation(address, localpart, medium); err != nil { + if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil { return httputil.LogThenError(req, err) } @@ -149,7 +149,7 @@ func GetAssociated3PIDs( return httputil.LogThenError(req, err) } - threepids, err := accountDB.GetThreePIDsForLocalpart(localpart) + threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } @@ -167,7 +167,7 @@ func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONRespon return *reqErr } - if err := accountDB.RemoveThreePIDAssociation(body.Address, body.Medium); err != nil { + if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/threepid/invites.go b/src/github.com/matrix-org/dendrite/clientapi/threepid/invites.go index d02c598a..51c0dd9c 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/threepid/invites.go +++ b/src/github.com/matrix-org/dendrite/clientapi/threepid/invites.go @@ -102,7 +102,7 @@ func CheckAndProcessInvite( return } - lookupRes, storeInviteRes, err := queryIDServer(db, cfg, device, body, roomID) + lookupRes, storeInviteRes, err := queryIDServer(ctx, db, cfg, device, body, roomID) if err != nil { return } @@ -134,6 +134,7 @@ func CheckAndProcessInvite( // Returns a representation of the response for both cases. // Returns an error if a check or a request failed. func queryIDServer( + ctx context.Context, db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { @@ -142,7 +143,7 @@ func queryIDServer( } // Lookup the 3PID - lookupRes, err = queryIDServerLookup(body) + lookupRes, err = queryIDServerLookup(ctx, body) if err != nil { return } @@ -150,7 +151,7 @@ func queryIDServer( if lookupRes.MXID == "" { // No Matrix ID matches with the given 3PID, ask the server to store the // invite and return a token - storeInviteRes, err = queryIDServerStoreInvite(db, cfg, device, body, roomID) + storeInviteRes, err = queryIDServerStoreInvite(ctx, db, cfg, device, body, roomID) return } @@ -161,11 +162,11 @@ func queryIDServer( if lookupRes.NotBefore > now || now > lookupRes.NotAfter { // If the current timestamp isn't in the time frame in which the association // is known to be valid, re-run the query - return queryIDServer(db, cfg, device, body, roomID) + return queryIDServer(ctx, db, cfg, device, body, roomID) } // Check the request signatures and send an error if one isn't valid - if err = checkIDServerSignatures(body, lookupRes); err != nil { + if err = checkIDServerSignatures(ctx, body, lookupRes); err != nil { return } @@ -175,10 +176,14 @@ func queryIDServer( // queryIDServerLookup sends a response to the identity server on /_matrix/identity/api/v1/lookup // and returns the response as a structure. // Returns an error if the request failed to send or if the response couldn't be parsed. -func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, error) { +func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServerLookupResponse, error) { address := url.QueryEscape(body.Address) url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/lookup?medium=%s&address=%s", body.IDServer, body.Medium, address) - resp, err := http.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -198,6 +203,7 @@ func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, erro // and returns the response as a structure. // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( + ctx context.Context, db *accounts.Database, cfg config.Dendrite, device *authtypes.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { @@ -209,7 +215,7 @@ func queryIDServerStoreInvite( var profile *authtypes.Profile if serverName == cfg.Matrix.ServerName { - profile, err = db.GetProfileByLocalpart(localpart) + profile, err = db.GetProfileByLocalpart(ctx, localpart) if err != nil { return nil, err } @@ -239,7 +245,7 @@ func queryIDServerStoreInvite( } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - resp, err := client.Do(req) + resp, err := client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -259,9 +265,13 @@ func queryIDServerStoreInvite( // We assume that the ID server is trusted at this point. // Returns an error if the request couldn't be sent, if its body couldn't be parsed // or if the key couldn't be decoded from base64. -func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) { +func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) ([]byte, error) { url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/%s", idServerName, keyID) - resp, err := http.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -286,7 +296,9 @@ func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) { // We assume that the ID server is trusted at this point. // Returns nil if all the verifications succeeded. // Returns an error if something failed in the process. -func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupResponse) error { +func checkIDServerSignatures( + ctx context.Context, body *MembershipRequest, res *idServerLookupResponse, +) error { // Mashall the body so we can give it to VerifyJSON marshalledBody, err := json.Marshal(*res) if err != nil { @@ -299,7 +311,7 @@ func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupRespons } for keyID := range signatures { - pubKey, err := queryIDServerPubKey(body.IDServer, keyID) + pubKey, err := queryIDServerPubKey(ctx, body.IDServer, keyID) if err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go b/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go index f07a3a14..1ca3a22a 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go +++ b/src/github.com/matrix-org/dendrite/clientapi/threepid/threepid.go @@ -15,6 +15,7 @@ package threepid import ( + "context" "encoding/json" "errors" "fmt" @@ -51,7 +52,9 @@ type Credentials struct { // Returns the session's ID. // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. -func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, error) { +func CreateSession( + ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite, +) (string, error) { if err := isTrusted(req.IDServer, cfg); err != nil { return "", err } @@ -71,7 +74,7 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er request.Header.Add("Content-Type", "application/x-www-form-urlencoded") client := http.Client{} - resp, err := client.Do(request) + resp, err := client.Do(request.WithContext(ctx)) if err != nil { return "", err } @@ -97,13 +100,19 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er // identifier and its medium. // Returns an error if there was a problem sending the request or decoding the // response, or if the identity server responded with a non-OK status. -func CheckAssociation(creds Credentials, cfg config.Dendrite) (bool, string, string, error) { +func CheckAssociation( + ctx context.Context, creds Credentials, cfg config.Dendrite, +) (bool, string, string, error) { if err := isTrusted(creds.IDServer, cfg); err != nil { return false, "", "", err } url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret) - resp, err := http.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return false, "", "", err + } + resp, err := http.DefaultClient.Do(req.WithContext(ctx)) if err != nil { return false, "", "", err } diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go b/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go index 3ee044b9..d948c919 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go @@ -127,7 +127,7 @@ func createRoom(req *http.Request, device *authtypes.Device, return httputil.LogThenError(req, err) } - profile, err := accountDB.GetProfileByLocalpart(localpart) + profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/joinroom.go b/src/github.com/matrix-org/dendrite/clientapi/writers/joinroom.go index 7eaed537..b6520d83 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/joinroom.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/joinroom.go @@ -57,7 +57,7 @@ func JoinRoomByIDOrAlias( return httputil.LogThenError(req, err) } - profile, err := accountDB.GetProfileByLocalpart(localpart) + profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart) if err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go b/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go index 45fea87f..41c42240 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go @@ -121,7 +121,7 @@ func buildMembershipEvent( return nil, err } - profile, err := loadProfile(stateKey, cfg, accountDB) + profile, err := loadProfile(ctx, stateKey, cfg, accountDB) if err != nil { return nil, err } @@ -156,7 +156,9 @@ func buildMembershipEvent( // it if the user is local to this server, or returns an empty profile if not. // Returns an error if the retrieval failed or if the first parameter isn't a // valid Matrix ID. -func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Database) (*authtypes.Profile, error) { +func loadProfile( + ctx context.Context, userID string, cfg config.Dendrite, accountDB *accounts.Database, +) (*authtypes.Profile, error) { localpart, serverName, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return nil, err @@ -164,7 +166,7 @@ func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Databas var profile *authtypes.Profile if serverName == cfg.Matrix.ServerName { - profile, err = accountDB.GetProfileByLocalpart(localpart) + profile, err = accountDB.GetProfileByLocalpart(ctx, localpart) } else { profile = &authtypes.Profile{} } diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go index 34c9bd65..a6f2c387 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go @@ -1,6 +1,7 @@ package writers import ( + "context" "fmt" "net/http" "time" @@ -134,7 +135,9 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices switch r.Auth.Type { case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(accountDB, deviceDB, r.Username, r.Password) + return completeRegistration( + req.Context(), accountDB, deviceDB, r.Username, r.Password, + ) default: return util.JSONResponse{ Code: 501, @@ -143,7 +146,12 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices } } -func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Database, username, password string) util.JSONResponse { +func completeRegistration( + ctx context.Context, + accountDB *accounts.Database, + deviceDB *devices.Database, + username, password string, +) util.JSONResponse { if username == "" { return util.JSONResponse{ Code: 400, @@ -157,7 +165,7 @@ func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Databa } } - acc, err := accountDB.CreateAccount(username, password) + acc, err := accountDB.CreateAccount(ctx, username, password) if err != nil { return util.JSONResponse{ Code: 500, diff --git a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go index bd5d2f78..82f1fec3 100644 --- a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go @@ -15,6 +15,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -68,7 +69,7 @@ func main() { os.Exit(1) } - _, err = accountDB.CreateAccount(*username, *password) + _, err = accountDB.CreateAccount(context.Background(), *username, *password) if err != nil { fmt.Println(err.Error()) os.Exit(1) diff --git a/src/github.com/matrix-org/dendrite/federationapi/writers/threepid.go b/src/github.com/matrix-org/dendrite/federationapi/writers/threepid.go index 0374dc39..2d7bac3d 100644 --- a/src/github.com/matrix-org/dendrite/federationapi/writers/threepid.go +++ b/src/github.com/matrix-org/dendrite/federationapi/writers/threepid.go @@ -191,7 +191,7 @@ func createInviteFrom3PIDInvite( StateKey: &inv.MXID, } - profile, err := accountDB.GetProfileByLocalpart(localpart) + profile, err := accountDB.GetProfileByLocalpart(ctx, localpart) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go index 5260a363..e8579271 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go @@ -15,6 +15,7 @@ package sync import ( + "context" "net/http" "strconv" "time" @@ -29,6 +30,7 @@ const defaultTimelineLimit = 20 // syncRequest represents a /sync request, with sensible defaults/sanity checks applied. type syncRequest struct { + ctx context.Context userID string limit int timeout time.Duration @@ -47,6 +49,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { } // TODO: Additional query params: set_presence, filter return &syncRequest{ + ctx: req.Context(), userID: userID, timeout: timeout, since: since, diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go index a207b815..922ee5d8 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go @@ -128,7 +128,7 @@ func (rp *RequestPool) appendAccountData( // already been sent. Instead, we send the whole batch. var global []gomatrixserverlib.ClientEvent var rooms map[string][]gomatrixserverlib.ClientEvent - global, rooms, err = rp.accountDB.GetAccountData(localpart) + global, rooms, err = rp.accountDB.GetAccountData(req.ctx, localpart) if err != nil { return nil, err } @@ -159,7 +159,9 @@ func (rp *RequestPool) appendAccountData( events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { - evs, err := rp.accountDB.GetAccountDataByType(localpart, roomID, dataType) + evs, err := rp.accountDB.GetAccountDataByType( + req.ctx, localpart, roomID, dataType, + ) if err != nil { return nil, err }