mirror of
https://github.com/1f349/tulip.git
synced 2024-09-19 18:16:28 +01:00
178 lines
4.5 KiB
Go
178 lines
4.5 KiB
Go
|
package database
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"github.com/1f349/tulip/password"
|
||
|
"github.com/go-oauth2/oauth2/v4"
|
||
|
"github.com/google/uuid"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
type Tx struct{ tx *sql.Tx }
|
||
|
|
||
|
func (t *Tx) Commit() error {
|
||
|
return t.tx.Commit()
|
||
|
}
|
||
|
|
||
|
func (t *Tx) Rollback() {
|
||
|
_ = t.tx.Rollback()
|
||
|
}
|
||
|
|
||
|
func (t *Tx) HasUser() error {
|
||
|
var exists bool
|
||
|
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
|
||
|
err := row.Scan(&exists)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !exists {
|
||
|
return sql.ErrNoRows
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *Tx) InsertUser(un, pw, email string) error {
|
||
|
pwHash, err := password.HashPassword(pw)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = t.tx.Exec(`INSERT INTO users (subject, username, password, email) VALUES (?, ?, ?, ?)`, uuid.NewString(), un, pwHash, email)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (t *Tx) CheckLogin(un, pw string) (*User, error) {
|
||
|
var u User
|
||
|
row := t.tx.QueryRow(`SELECT subject, password FROM users WHERE username = ? LIMIT 1`, un)
|
||
|
err := row.Scan(&u.Sub, &u.Password)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
err = password.CheckPasswordHash(u.Password, pw)
|
||
|
return &u, err
|
||
|
}
|
||
|
|
||
|
func (t *Tx) GetUserDisplayName(sub uuid.UUID) (*User, error) {
|
||
|
var u User
|
||
|
row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub.String())
|
||
|
err := row.Scan(&u.Name)
|
||
|
u.Sub = sub
|
||
|
return &u, err
|
||
|
}
|
||
|
|
||
|
func (t *Tx) GetUser(sub uuid.UUID) (*User, error) {
|
||
|
var u User
|
||
|
row := t.tx.QueryRow(`SELECT name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ? LIMIT 1`, sub.String())
|
||
|
err := row.Scan(&u.Name, &u.Username, &u.Password, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
|
||
|
u.Sub = sub
|
||
|
return &u, err
|
||
|
}
|
||
|
|
||
|
func (t *Tx) ChangeUserPassword(sub uuid.UUID, pwOld, pwNew string) error {
|
||
|
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
var pwHash string
|
||
|
if q.Next() {
|
||
|
err = q.Scan(&pwHash)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
return fmt.Errorf("invalid user")
|
||
|
}
|
||
|
if err := q.Err(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if err := q.Close(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
err = password.CheckPasswordHash(pwHash, pwOld)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
pwNewHash, err := password.HashPassword(pwNew)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, time.Now().Format(time.DateTime), sub, pwHash)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
affected, err := exec.RowsAffected()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if affected != 1 {
|
||
|
return fmt.Errorf("row wasn't updated")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *Tx) ModifyUser(sub uuid.UUID, v *UserPatch) error {
|
||
|
exec, err := t.tx.Exec(
|
||
|
`UPDATE users
|
||
|
SET name = ifnull(?, name),
|
||
|
picture = ifnull(?, picture),
|
||
|
website = ifnull(?, website),
|
||
|
pronouns = ifnull(?, pronouns),
|
||
|
birthdate = ifnull(?, birthdate),
|
||
|
zoneinfo = ifnull(?, zoneinfo),
|
||
|
locale = ifnull(?, locale),
|
||
|
updated_at = ?
|
||
|
WHERE subject = ?`,
|
||
|
v.Name,
|
||
|
stringify(v.Picture),
|
||
|
stringify(v.Website),
|
||
|
v.Pronouns.String(),
|
||
|
sql.NullTime{Time: v.Birthdate, Valid: !v.Birthdate.IsZero()},
|
||
|
v.ZoneInfo.String(),
|
||
|
v.Locale.String(),
|
||
|
time.Now().Format(time.DateTime),
|
||
|
sub,
|
||
|
)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
affected, err := exec.RowsAffected()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if affected != 1 {
|
||
|
return fmt.Errorf("row wasn't updated")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
|
||
|
var u clientInfoDbOutput
|
||
|
row := t.tx.QueryRow(`SELECT secret, domain, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
|
||
|
err := row.Scan(&u.secret, &u.domain, &u.sso)
|
||
|
u.sub = sub
|
||
|
return &u, err
|
||
|
}
|
||
|
|
||
|
type clientInfoDbOutput struct {
|
||
|
sub, secret, domain string
|
||
|
sso bool
|
||
|
}
|
||
|
|
||
|
func (c *clientInfoDbOutput) GetID() string { return c.sub }
|
||
|
func (c *clientInfoDbOutput) GetSecret() string { return c.secret }
|
||
|
func (c *clientInfoDbOutput) GetDomain() string { return c.domain }
|
||
|
func (c *clientInfoDbOutput) IsPublic() bool { return false }
|
||
|
func (c *clientInfoDbOutput) GetUserID() string { return "" }
|
||
|
func (c *clientInfoDbOutput) IsSSO() bool { return c.sso }
|
||
|
|
||
|
func stringify(stringer fmt.Stringer) sql.NullString {
|
||
|
if stringer == nil {
|
||
|
return sql.NullString{}
|
||
|
}
|
||
|
return emptyToNull(stringer.String())
|
||
|
}
|
||
|
|
||
|
func emptyToNull(a string) sql.NullString {
|
||
|
return sql.NullString{String: a, Valid: a != ""}
|
||
|
}
|