Prevent sql scanning into nil value in accounts_table (#479)

* Prevent sql scanning into nil value in accounts_table

Signed-off-by: Andrew Morgan <andrewm@matrix.org>

* Remove uneccessary logging, null checking

* Don't forget to set the localpart

* Simplify error checking
This commit is contained in:
Andrew Morgan 2018-06-29 03:55:29 -07:00 committed by GitHub
parent af08eea46d
commit 141fd91537
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -22,6 +22,8 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
const accountsSchema = ` const accountsSchema = `
@ -121,14 +123,26 @@ func (s *accountsStatements) selectPasswordHash(
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) selectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (acc *authtypes.Account, err error) { ) (*authtypes.Account, error) {
var appserviceIDPtr sql.NullString
var acc authtypes.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err = stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &acc.AppServiceID) err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
if err == nil { if err != nil {
acc.UserID = userutil.MakeUserID(localpart, s.serverName) if err != sql.ErrNoRows {
acc.ServerName = s.serverName log.WithError(err).Error("Unable to retrieve user from the db")
}
return nil, err
} }
return if appserviceIDPtr.Valid {
acc.AppServiceID = appserviceIDPtr.String
}
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
return &acc, nil
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) selectNewNumericLocalpart(