tulip/database/tx.go

295 lines
8.3 KiB
Go
Raw Normal View History

2023-09-06 22:20:09 +01:00
package database
import (
"database/sql"
"fmt"
"github.com/1f349/tulip/password"
"github.com/1f349/twofactor"
2023-09-06 22:20:09 +01:00
"github.com/go-oauth2/oauth2/v4"
"github.com/google/uuid"
"time"
)
// updatedAt returns the current instant converted to UTC and rendered with
// the time.DateTime layout ("2006-01-02 15:04:05"). It is the value written
// to the updated_at columns.
func updatedAt() string {
	now := time.Now().UTC()
	return now.Format(time.DateTime)
}
2023-09-06 22:20:09 +01:00
// Tx wraps a *sql.Tx; the query helpers in this package are methods on it.
type Tx struct{ tx *sql.Tx }

// Commit commits the wrapped transaction and returns the driver's error, if
// any.
func (t *Tx) Commit() error {
	err := t.tx.Commit()
	return err
}

// Rollback aborts the wrapped transaction. Any error is deliberately
// discarded; per the database/sql docs, rolling back a transaction that was
// already committed or rolled back returns sql.ErrTxDone, which is not
// actionable here.
func (t *Tx) Rollback() {
	if err := t.tx.Rollback(); err != nil {
		return
	}
}
// HasUser reports whether the users table contains at least one row.
//
// It returns nil when a user exists, sql.ErrNoRows when the table is empty,
// and the query/scan error otherwise.
func (t *Tx) HasUser() error {
	var found bool
	if err := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`).Scan(&found); err != nil {
		return err
	}
	if found {
		return nil
	}
	return sql.ErrNoRows
}
func (t *Tx) InsertUser(name, un, pw, email string, role UserRole, active bool) error {
2023-09-06 22:20:09 +01:00
pwHash, err := password.HashPassword(pw)
if err != nil {
return err
}
_, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email, role, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, uuid.NewString(), name, un, pwHash, email, role, updatedAt(), active)
2023-09-06 22:20:09 +01:00
return err
}
func (t *Tx) CheckLogin(un, pw string) (*User, bool, error) {
2023-09-06 22:20:09 +01:00
var u User
var hasOtp bool
row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) FROM users WHERE username = ?`, un)
err := row.Scan(&u.Sub, &u.Password, &hasOtp)
2023-09-06 22:20:09 +01:00
if err != nil {
return nil, false, err
2023-09-06 22:20:09 +01:00
}
err = password.CheckPasswordHash(u.Password, pw)
return &u, hasOtp, err
2023-09-06 22:20:09 +01:00
}
// GetUserDisplayName loads only the name column of the user identified by
// sub.
//
// The returned *User has Sub and Name set. It is non-nil even when err is
// non-nil; in that case Name may be empty.
func (t *Tx) GetUserDisplayName(sub uuid.UUID) (*User, error) {
	u := &User{}
	err := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub.String()).Scan(&u.Name)
	u.Sub = sub
	return u, err
}
// GetUserRole returns the role column of the user identified by sub,
// together with any error from the lookup.
func (t *Tx) GetUserRole(sub uuid.UUID) (UserRole, error) {
	var role UserRole
	err := t.tx.QueryRow(`SELECT role FROM users WHERE subject = ? LIMIT 1`, sub.String()).Scan(&role)
	return role, err
}
2023-09-06 22:20:09 +01:00
func (t *Tx) GetUser(sub uuid.UUID) (*User, error) {
var u User
row := t.tx.QueryRow(`SELECT name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub.String())
2023-09-06 22:20:09 +01:00
err := row.Scan(&u.Name, &u.Username, &u.Password, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
u.Sub = sub
return &u, err
}
// GetUserEmail returns the email column of the user identified by sub.
func (t *Tx) GetUserEmail(sub uuid.UUID) (string, error) {
	var email string
	if err := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub.String()).Scan(&email); err != nil {
		return email, err
	}
	return email, nil
}
2023-09-06 22:20:09 +01:00
func (t *Tx) ChangeUserPassword(sub uuid.UUID, pwOld, pwNew string) error {
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
if err != nil {
return err
}
var pwHash string
if q.Next() {
err = q.Scan(&pwHash)
if err != nil {
return err
}
} else {
return fmt.Errorf("invalid user")
}
if err := q.Err(); err != nil {
return err
}
if err := q.Close(); err != nil {
return err
}
err = password.CheckPasswordHash(pwHash, pwOld)
if err != nil {
return err
}
pwNewHash, err := password.HashPassword(pwNew)
if err != nil {
return err
}
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, updatedAt(), sub, pwHash)
2023-09-06 22:20:09 +01:00
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) ModifyUser(sub uuid.UUID, v *UserPatch) error {
exec, err := t.tx.Exec(
`UPDATE users
SET name = ?,
picture = ?,
website = ?,
pronouns = ?,
birthdate = ?,
zoneinfo = ?,
locale = ?,
2023-09-06 22:20:09 +01:00
updated_at = ?
WHERE subject = ?`,
v.Name,
v.Picture,
v.Website,
2023-09-06 22:20:09 +01:00
v.Pronouns.String(),
v.Birthdate,
2023-09-06 22:20:09 +01:00
v.ZoneInfo.String(),
v.Locale.String(),
updatedAt(),
2023-09-06 22:20:09 +01:00
sub,
)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
// SetTwoFactor serialises totp with ToBytes and stores it as the raw otp
// value for sub. A new row is inserted, or an existing row's raw value is
// replaced (upsert on subject).
func (t *Tx) SetTwoFactor(sub uuid.UUID, totp *twofactor.Totp) error {
	raw, err := totp.ToBytes()
	if err != nil {
		return err
	}
	if _, err := t.tx.Exec(`INSERT INTO otp(subject, raw) VALUES (?, ?) ON CONFLICT(subject) DO UPDATE SET raw = excluded.raw`, sub.String(), raw); err != nil {
		return err
	}
	return nil
}
// GetTwoFactor reads the raw otp bytes stored for sub and decodes them with
// twofactor.TOTPFromBytes using the given issuer. When sub has no otp row,
// the lookup error (sql.ErrNoRows) is returned with a nil TOTP.
func (t *Tx) GetTwoFactor(sub uuid.UUID, issuer string) (*twofactor.Totp, error) {
	var raw []byte
	if err := t.tx.QueryRow(`SELECT raw FROM otp WHERE subject = ?`, sub.String()).Scan(&raw); err != nil {
		return nil, err
	}
	return twofactor.TOTPFromBytes(raw, issuer)
}
// HasTwoFactor reports whether an otp row exists for sub.
func (t *Tx) HasTwoFactor(sub uuid.UUID) (bool, error) {
	row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM otp WHERE otp.subject = ?)`, sub)
	var exists bool
	if err := row.Scan(&exists); err != nil {
		return false, err
	}
	return exists, row.Err()
}
2023-09-06 22:20:09 +01:00
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
var u ClientInfoDbOutput
row := t.tx.QueryRow(`SELECT secret, name, domain, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.SSO, &u.Active)
u.Owner = sub
if !u.Active {
return nil, fmt.Errorf("client is not active")
}
2023-09-06 22:20:09 +01:00
return &u, err
}
func (t *Tx) GetAppList(offset int) ([]ClientInfoDbOutput, error) {
var u []ClientInfoDbOutput
row, err := t.tx.Query(`SELECT subject, name, domain, owner, sso, active FROM client_store LIMIT 25 OFFSET ?`, offset)
if err != nil {
return nil, err
}
defer row.Close()
for row.Next() {
var a ClientInfoDbOutput
err := row.Scan(&a.Sub, &a.Name, &a.Domain, &a.Owner, &a.SSO, &a.Active)
if err != nil {
return nil, err
}
u = append(u, a)
}
return u, row.Err()
2023-09-06 22:20:09 +01:00
}
// InsertClientApp registers a new OAuth2 client owned by owner.
//
// A random UUID subject and a secret from password.GenerateApiSecret(70) are
// generated here. Neither value is returned to the caller.
func (t *Tx) InsertClientApp(name, domain string, sso, active bool, owner uuid.UUID) error {
	secret, err := password.GenerateApiSecret(70)
	if err != nil {
		return err
	}
	subject := uuid.New()
	_, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?)`,
		subject.String(), name, secret, domain, owner.String(), sso, active)
	return err
}
// UpdateClientApp overwrites the name, domain, sso and active columns of the
// client identified by subject. It does not check how many rows matched.
func (t *Tx) UpdateClientApp(subject uuid.UUID, name, domain string, sso, active bool) error {
	if _, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, sso = ?, active = ? WHERE subject = ?`, name, domain, sso, active, subject.String()); err != nil {
		return err
	}
	return nil
}
// ResetClientAppSecret replaces the secret of the OAuth2 client identified
// by subject.
//
// A non-empty secret argument is stored as-is. The caller can then rely on
// the value it passed (e.g. to show it to the user) being the one saved. If
// secret is empty, a new one is generated with
// password.GenerateApiSecret(70), so an empty secret is never written.
//
// Earlier code used `secret, err :=`, which reassigned the parameter and
// silently discarded the caller's value.
func (t *Tx) ResetClientAppSecret(subject uuid.UUID, secret string) error {
	if secret == "" {
		generated, err := password.GenerateApiSecret(70)
		if err != nil {
			return err
		}
		secret = generated
	}
	_, err := t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ?`, secret, subject.String())
	return err
}
// GetUserList returns up to 25 users, skipping the first offset rows. The
// password column is not selected, so Password is empty on every entry. When
// there are no rows, the returned slice is nil.
func (t *Tx) GetUserList(offset int) ([]User, error) {
	rows, err := t.tx.Query(`SELECT subject, name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset)
	if err != nil {
		return nil, err
	}
	// Close the result set on every return path. Without this, an error from
	// Scan left the rows open on the transaction (GetAppList already does
	// this).
	defer rows.Close()

	var users []User
	for rows.Next() {
		var a User
		if err := rows.Scan(&a.Sub, &a.Name, &a.Username, &a.Picture, &a.Website, &a.Email, &a.EmailVerified, &a.Pronouns, &a.Birthdate, &a.ZoneInfo, &a.Locale, &a.Role, &a.UpdatedAt, &a.Active); err != nil {
			return nil, err
		}
		users = append(users, a)
	}
	return users, rows.Err()
}
// UpdateUser sets the active flag and role of the user identified by subject.
// It does not touch updated_at and does not check how many rows matched.
func (t *Tx) UpdateUser(subject uuid.UUID, role UserRole, active bool) error {
	if _, err := t.tx.Exec(`UPDATE users SET active = ?, role = ? WHERE subject = ?`, active, role, subject); err != nil {
		return err
	}
	return nil
}
// ClientInfoDbOutput holds the client_store columns read by GetClientInfo
// and GetAppList. It implements oauth2.ClientInfo plus a few extra
// accessors.
type ClientInfoDbOutput struct {
	Sub    string // client subject (UUID string)
	Name   string // display name of the application
	Secret string // client secret; not selected by GetAppList
	Domain string // domain associated with the client
	Owner  string // value returned by GetUserID
	SSO    bool   // trusted client, see IsSSO
	Active bool   // whether the client is enabled
}

// Compile-time check that *ClientInfoDbOutput satisfies oauth2.ClientInfo.
var _ oauth2.ClientInfo = (*ClientInfoDbOutput)(nil)

// GetID returns the client subject.
func (c *ClientInfoDbOutput) GetID() string {
	return c.Sub
}

// GetSecret returns the stored client secret.
func (c *ClientInfoDbOutput) GetSecret() string {
	return c.Secret
}

// GetDomain returns the domain stored for the client.
func (c *ClientInfoDbOutput) GetDomain() string {
	return c.Domain
}

// IsPublic always reports false: every client is treated as confidential.
func (c *ClientInfoDbOutput) IsPublic() bool {
	return false
}

// GetUserID returns the Owner field.
func (c *ClientInfoDbOutput) GetUserID() string {
	return c.Owner
}

// GetName returns the application name. It is not part of oauth2.ClientInfo;
// the oauth handler uses it to display which application is asking.
func (c *ClientInfoDbOutput) GetName() string {
	return c.Name
}

// IsSSO reports whether the client is flagged as SSO. It is not part of
// oauth2.ClientInfo. The oauth handler uses it to skip the user input stage,
// so trusted applications get permissions without asking the user.
func (c *ClientInfoDbOutput) IsSSO() bool {
	return c.SSO
}

// IsActive reports the client's active flag. It is not part of
// oauth2.ClientInfo; the app manager uses it to show the active state.
func (c *ClientInfoDbOutput) IsActive() bool {
	return c.Active
}