diff --git a/database/db-types.go b/database/db-types.go index 7887a9b..6e0e0e4 100644 --- a/database/db-types.go +++ b/database/db-types.go @@ -15,7 +15,6 @@ type User struct { Sub uuid.UUID `json:"sub"` Name string `json:"name,omitempty"` Username string `json:"username"` - Password string `json:"password"` Picture NullStringScanner `json:"picture,omitempty"` Website NullStringScanner `json:"website,omitempty"` Email string `json:"email"` @@ -104,6 +103,11 @@ func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) { return } +type ClientInfoDbOutput struct { + Sub, Name, Secret, Domain, Owner string + SSO, Active bool +} + var _ oauth2.ClientInfo = &ClientInfoDbOutput{} func (c *ClientInfoDbOutput) GetID() string { return c.Sub } diff --git a/database/tx.go b/database/tx.go index a8ff289..818fb52 100644 --- a/database/tx.go +++ b/database/tx.go @@ -49,13 +49,14 @@ func (t *Tx) InsertUser(name, un, pw, email string, role UserRole, active bool) func (t *Tx) CheckLogin(un, pw string) (*User, bool, bool, error) { var u User + var pwHash password.HashString var hasOtp, hasVerify bool row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject), email, email_verified FROM users WHERE username = ?`, un) - err := row.Scan(&u.Sub, &u.Password, &hasOtp, &u.Email, &hasVerify) + err := row.Scan(&u.Sub, &pwHash, &hasOtp, &u.Email, &hasVerify) if err != nil { return nil, false, false, err } - err = password.CheckPasswordHash(u.Password, pw) + err = password.CheckPasswordHash(pwHash, pw) return &u, hasOtp, hasVerify, err } @@ -76,8 +77,8 @@ func (t *Tx) GetUserRole(sub uuid.UUID) (UserRole, error) { func (t *Tx) GetUser(sub uuid.UUID) (*User, error) { var u User - row := t.tx.QueryRow(`SELECT name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub.String()) - err := row.Scan(&u.Name, &u.Username, &u.Password, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active) + row := t.tx.QueryRow(`SELECT name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub.String()) + err := row.Scan(&u.Name, &u.Username, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active) u.Sub = sub return &u, err } @@ -94,7 +95,7 @@ func (t *Tx) ChangeUserPassword(sub uuid.UUID, pwOld, pwNew string) error { if err != nil { return err } - var pwHash string + var pwHash password.HashString if q.Next() { err = q.Scan(&pwHash) if err != nil { @@ -275,7 +276,27 @@ func (t *Tx) VerifyUserEmail(sub uuid.UUID) error { return err } -type ClientInfoDbOutput struct { - Sub, Name, Secret, Domain, Owner string - SSO, Active bool +func (t *Tx) UserResetPassword(sub uuid.UUID, pw string) error { + hashPassword, err := password.HashPassword(pw) + if err != nil { + return err + } + exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ?`, hashPassword, updatedAt(), sub.String()) + if err != nil { + return err + } + affected, err := exec.RowsAffected() + if err != nil { + return err + } + if affected != 1 { + return fmt.Errorf("row wasn't updated") + } + return nil +} + +func (t *Tx) UserEmailExists(email string) (exists bool, err error) { + row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email) + err = row.Scan(&exists) + return } diff --git a/database/tx_test.go b/database/tx_test.go index 789d1da..09fef36 100644 --- a/database/tx_test.go +++ b/database/tx_test.go @@ -26,7 +26,7 @@ func TestTx_ChangeUserPassword(t *testing.T) { query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test") assert.NoError(t, err) assert.True(t, query.Next()) - var oldPw string + var oldPw password.HashString assert.NoError(t, query.Scan(&oldPw)) assert.NoError(t, password.CheckPasswordHash(oldPw, "new")) assert.NoError(t, query.Err()) diff --git a/pages/login.go.html b/pages/login.go.html index 4802a16..49cd4d7 100644 --- a/pages/login.go.html +++ b/pages/login.go.html @@ -26,6 +26,16 @@ + +