Correct refactored database calls

This commit is contained in:
Melon 2024-03-11 12:39:52 +00:00
parent fbd49da2db
commit 37570e2157
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
39 changed files with 604 additions and 934 deletions

View File

@ -7,20 +7,16 @@ import (
) )
type ClientStore struct { type ClientStore struct {
db *database.DB db *database.Queries
} }
var _ oauth2.ClientStore = &ClientStore{} var _ oauth2.ClientStore = &ClientStore{}
func New(db *database.DB) *ClientStore { func New(db *database.Queries) *ClientStore {
return &ClientStore{db: db} return &ClientStore{db: db}
} }
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
tx, err := c.db.BeginCtx(ctx) a, err := c.db.GetClientInfo(ctx, id)
if err != nil { return &a, err
return nil, err
}
defer tx.Rollback()
return tx.GetClientInfo(id)
} }

View File

@ -118,7 +118,7 @@ func genHmacKey() []byte {
return a return a
} }
func checkDbHasUser(db *database.DB) error { func checkDbHasUser(db *database.Queries) error {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
return fmt.Errorf("failed to start transaction: %w", err) return fmt.Errorf("failed to start transaction: %w", err)
@ -126,7 +126,7 @@ func checkDbHasUser(db *database.DB) error {
defer tx.Rollback() defer tx.Rollback()
if err := tx.HasUser(); err != nil { if err := tx.HasUser(); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, database.RoleAdmin, false) _, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, types.RoleAdmin, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to add user: %w", err) return fmt.Errorf("failed to add user: %w", err)
} }

View File

@ -0,0 +1,82 @@
package database
import (
"database/sql"
"fmt"
"github.com/1f349/tulip/database/types"
"github.com/MrMelon54/pronouns"
"github.com/go-oauth2/oauth2/v4"
"golang.org/x/text/language"
"net/url"
"time"
)
type UserPatch struct {
Name string
Picture string
Website string
Pronouns types.UserPronoun
Birthdate sql.NullTime
ZoneInfo types.UserZone
Locale types.UserLocale
}
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
var err error
u.Name = v.Get("name")
u.Picture = v.Get("picture")
u.Website = v.Get("website")
if v.Has("reset_pronouns") {
u.Pronouns.Pronoun = pronouns.TheyThem
} else {
u.Pronouns.Pronoun, err = pronouns.FindPronoun(v.Get("pronouns"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
}
}
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
u.Birthdate = sql.NullTime{}
} else {
u.Birthdate = sql.NullTime{Valid: true}
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
}
}
if v.Has("reset_zoneinfo") {
u.ZoneInfo.Location = time.UTC
} else {
u.ZoneInfo.Location, err = time.LoadLocation(v.Get("zoneinfo"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
}
}
if v.Has("reset_locale") {
u.Locale.Tag = language.AmericanEnglish
} else {
u.Locale.Tag, err = language.Parse(v.Get("locale"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
}
}
return
}
var _ oauth2.ClientInfo = &ClientStore{}
func (c *ClientStore) GetID() string { return c.Subject }
func (c *ClientStore) GetSecret() string { return c.Secret }
func (c *ClientStore) GetDomain() string { return c.Domain }
func (c *ClientStore) IsPublic() bool { return c.Public }
func (c *ClientStore) GetUserID() string { return c.Owner }
// GetName is an extra field for the oauth handler to display the application
// name
func (c *ClientStore) GetName() string { return c.Name }
// IsSSO is an extra field for the oauth handler to skip the user input stage
// this is for trusted applications to get permissions without asking the user
func (c *ClientStore) IsSSO() bool { return c.Sso }
// IsActive is an extra field for the app manager to get the active state
func (c *ClientStore) IsActive() bool { return c.Active }

View File

@ -1,145 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/MrMelon54/pronouns"
"golang.org/x/text/language"
"time"
)
var (
_, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
_, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
_, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
)
func marshalValueOrNull(null bool, data any) ([]byte, error) {
if null {
return json.Marshal(nil)
}
return json.Marshal(data)
}
type NullStringScanner struct{ sql.NullString }
func (s *NullStringScanner) Null() bool { return !s.Valid }
func (s *NullStringScanner) Scan(src any) error { return s.NullString.Scan(src) }
func (s NullStringScanner) MarshalJSON() ([]byte, error) {
return marshalValueOrNull(s.Null(), s.NullString.String)
}
func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error {
if string(bytes) == "null" {
return s.Scan(nil)
}
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return s.Scan(&a)
}
func (s NullStringScanner) String() string {
if s.Null() {
return ""
}
return s.NullString.String
}
type NullDateScanner struct{ sql.NullTime }
func (t *NullDateScanner) Null() bool { return !t.Valid }
func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) }
func (t NullDateScanner) MarshalJSON() ([]byte, error) {
return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly))
}
func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error {
if string(bytes) == "null" {
return t.Scan(nil)
}
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return t.Scan(&a)
}
func (t NullDateScanner) String() string {
if t.Null() {
return ""
}
return t.NullTime.Time.UTC().Format(time.DateOnly)
}
type LocationScanner struct{ *time.Location }
func (l *LocationScanner) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
loc, err := time.LoadLocation(s)
if err != nil {
return err
}
l.Location = loc
return nil
}
func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) }
func (l *LocationScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}
type LocaleScanner struct{ language.Tag }
func (l *LocaleScanner) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
lang, err := language.Parse(s)
if err != nil {
return err
}
l.Tag = lang
return nil
}
func (l LocaleScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
func (l *LocaleScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}
type PronounScanner struct{ pronouns.Pronoun }
func (p *PronounScanner) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
}
pro, err := pronouns.FindPronoun(s)
if err != nil {
return err
}
p.Pronoun = pro
return nil
}
func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
func (p *PronounScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return p.Scan(a)
}

View File

@ -1,52 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"github.com/MrMelon54/pronouns"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"testing"
"time"
)
func encode(data any) string {
j, err := json.Marshal(map[string]any{"value": data})
if err != nil {
panic(err)
}
return string(j)
}
func TestStringScanner_MarshalJSON(t *testing.T) {
assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}}))
assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}}))
}
func TestDateScanner_MarshalJSON(t *testing.T) {
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}}))
assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}}))
assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{}))
}
func TestLocationScanner_MarshalJSON(t *testing.T) {
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location}))
assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC}))
}
func TestLocaleScanner_MarshalJSON(t *testing.T) {
assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish}))
assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish}))
}
func TestPronounScanner_MarshalJSON(t *testing.T) {
assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem}))
assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim}))
assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer}))
assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts}))
assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes}))
}

View File

@ -1,127 +0,0 @@
package database
import (
"database/sql"
"fmt"
"github.com/MrMelon54/pronouns"
"github.com/go-oauth2/oauth2/v4"
"golang.org/x/text/language"
"net/url"
"time"
)
type User struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
Username string `json:"username"`
Picture NullStringScanner `json:"picture,omitempty"`
Website NullStringScanner `json:"website,omitempty"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Pronouns PronounScanner `json:"pronouns,omitempty"`
Birthdate NullDateScanner `json:"birthdate,omitempty"`
ZoneInfo LocationScanner `json:"zoneinfo,omitempty"`
Locale LocaleScanner `json:"locale,omitempty"`
Role UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Active bool `json:"active"`
}
type UserRole int
const (
RoleMember UserRole = iota
RoleAdmin
RoleToDelete
)
func (r UserRole) String() string {
switch r {
case RoleMember:
return "Member"
case RoleAdmin:
return "Admin"
case RoleToDelete:
return "ToDelete"
}
return fmt.Sprintf("UserRole{ %d }", r)
}
func (r UserRole) IsValid() bool {
return r == RoleMember || r == RoleAdmin
}
type UserPatch struct {
Name string
Picture string
Website string
Pronouns pronouns.Pronoun
Birthdate sql.NullTime
ZoneInfo *time.Location
Locale language.Tag
}
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
var err error
u.Name = v.Get("name")
u.Picture = v.Get("picture")
u.Website = v.Get("website")
if v.Has("reset_pronouns") {
u.Pronouns = pronouns.TheyThem
} else {
u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
}
}
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
u.Birthdate = sql.NullTime{}
} else {
u.Birthdate = sql.NullTime{Valid: true}
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
}
}
if v.Has("reset_zoneinfo") {
u.ZoneInfo = time.UTC
} else {
u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
}
}
if v.Has("reset_locale") {
u.Locale = language.AmericanEnglish
} else {
u.Locale, err = language.Parse(v.Get("locale"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
}
}
return
}
type ClientInfoDbOutput struct {
Sub, Name, Secret, Domain, Owner string
Public, SSO, Active bool
}
var _ oauth2.ClientInfo = &ClientInfoDbOutput{}
func (c *ClientInfoDbOutput) GetID() string { return c.Sub }
func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret }
func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain }
func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public }
func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner }
// GetName is an extra field for the oauth handler to display the application
// name
func (c *ClientInfoDbOutput) GetName() string { return c.Name }
// IsSSO is an extra field for the oauth handler to skip the user input stage
// this is for trusted applications to get permissions without asking the user
func (c *ClientInfoDbOutput) IsSSO() bool { return c.SSO }
// IsActive is an extra field for the app manager to get the active state
func (c *ClientInfoDbOutput) IsActive() bool { return c.Active }

View File

@ -1,5 +0,0 @@
package database
import (
_ "github.com/mattn/go-sqlite3"
)

View File

@ -7,7 +7,6 @@ package database
import ( import (
"context" "context"
"database/sql"
) )
const getAppList = `-- name: GetAppList :many const getAppList = `-- name: GetAppList :many
@ -25,13 +24,13 @@ type GetAppListParams struct {
} }
type GetAppListRow struct { type GetAppListRow struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Name string `json:"name"` Name string `json:"name"`
Domain string `json:"domain"` Domain string `json:"domain"`
Owner string `json:"owner"` Owner string `json:"owner"`
Public sql.NullInt64 `json:"public"` Public bool `json:"public"`
Sso sql.NullInt64 `json:"sso"` Sso bool `json:"sso"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
} }
func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) { func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) {
@ -66,28 +65,21 @@ func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAp
} }
const getClientInfo = `-- name: GetClientInfo :one const getClientInfo = `-- name: GetClientInfo :one
SELECT secret, name, domain, public, sso, active SELECT subject, name, secret, domain, owner, public, sso, active
FROM client_store FROM client_store
WHERE subject = ? WHERE subject = ?
LIMIT 1 LIMIT 1
` `
type GetClientInfoRow struct { func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) {
Secret string `json:"secret"`
Name string `json:"name"`
Domain string `json:"domain"`
Public sql.NullInt64 `json:"public"`
Sso sql.NullInt64 `json:"sso"`
Active sql.NullInt64 `json:"active"`
}
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (GetClientInfoRow, error) {
row := q.db.QueryRowContext(ctx, getClientInfo, subject) row := q.db.QueryRowContext(ctx, getClientInfo, subject)
var i GetClientInfoRow var i ClientStore
err := row.Scan( err := row.Scan(
&i.Secret, &i.Subject,
&i.Name, &i.Name,
&i.Secret,
&i.Domain, &i.Domain,
&i.Owner,
&i.Public, &i.Public,
&i.Sso, &i.Sso,
&i.Active, &i.Active,
@ -101,14 +93,14 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
` `
type InsertClientAppParams struct { type InsertClientAppParams struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Name string `json:"name"` Name string `json:"name"`
Secret string `json:"secret"` Secret string `json:"secret"`
Domain string `json:"domain"` Domain string `json:"domain"`
Owner string `json:"owner"` Owner string `json:"owner"`
Public sql.NullInt64 `json:"public"` Public bool `json:"public"`
Sso sql.NullInt64 `json:"sso"` Sso bool `json:"sso"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
} }
func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error { func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error {
@ -137,13 +129,13 @@ WHERE subject = ?
` `
type UpdateClientAppParams struct { type UpdateClientAppParams struct {
Name string `json:"name"` Name string `json:"name"`
Domain string `json:"domain"` Domain string `json:"domain"`
Public sql.NullInt64 `json:"public"` Public bool `json:"public"`
Sso sql.NullInt64 `json:"sso"` Sso bool `json:"sso"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
Subject string `json:"subject"` Subject string `json:"subject"`
Owner string `json:"owner"` Owner string `json:"owner"`
} }
func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error { func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error {

View File

@ -7,7 +7,9 @@ package database
import ( import (
"context" "context"
"database/sql" "time"
"github.com/1f349/tulip/database/types"
) )
const getUserList = `-- name: GetUserList :many const getUserList = `-- name: GetUserList :many
@ -25,15 +27,15 @@ LIMIT 25 OFFSET ?
` `
type GetUserListRow struct { type GetUserListRow struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Name string `json:"name"` Name string `json:"name"`
Username string `json:"username"` Username string `json:"username"`
Picture interface{} `json:"picture"` Picture string `json:"picture"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified int64 `json:"email_verified"` EmailVerified bool `json:"email_verified"`
Role int64 `json:"role"` Role types.UserRole `json:"role"`
UpdatedAt sql.NullTime `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
} }
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) { func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
@ -77,9 +79,9 @@ WHERE subject = ?
` `
type UpdateUserRoleParams struct { type UpdateUserRoleParams struct {
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
Role int64 `json:"role"` Role types.UserRole `json:"role"`
Subject string `json:"subject"` Subject string `json:"subject"`
} }
func (q *Queries) UpdateUserRole(ctx context.Context, arg UpdateUserRoleParams) error { func (q *Queries) UpdateUserRole(ctx context.Context, arg UpdateUserRoleParams) error {

View File

@ -0,0 +1,4 @@
DROP TABLE users;
DROP INDEX username_index;
DROP TABLE client_store;
DROP TABLE otp;

View File

@ -1,39 +1,39 @@
CREATE TABLE IF NOT EXISTS users CREATE TABLE users
( (
subject TEXT PRIMARY KEY UNIQUE NOT NULL, subject TEXT PRIMARY KEY UNIQUE NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
username TEXT UNIQUE NOT NULL, username TEXT UNIQUE NOT NULL,
password TEXT NOT NULL, password TEXT NOT NULL,
picture TEXT DEFAULT "" NOT NULL, picture TEXT DEFAULT '' NOT NULL,
website TEXT DEFAULT "" NOT NULL, website TEXT DEFAULT '' NOT NULL,
email TEXT NOT NULL, email TEXT NOT NULL,
email_verified INTEGER DEFAULT 0 NOT NULL, email_verified BOOLEAN DEFAULT 0 NOT NULL,
pronouns TEXT DEFAULT "they/them" NOT NULL, pronouns TEXT DEFAULT 'they/them' NOT NULL,
birthdate DATE, birthdate DATE,
zoneinfo TEXT DEFAULT "UTC" NOT NULL, zoneinfo TEXT DEFAULT 'UTC' NOT NULL,
locale TEXT DEFAULT "en-US" NOT NULL, locale TEXT DEFAULT 'en-US' NOT NULL,
role INTEGER DEFAULT 0 NOT NULL, role INTEGER DEFAULT 0 NOT NULL,
updated_at DATETIME, updated_at DATETIME NOT NULL,
registered INTEGER DEFAULT 0, registered DATETIME NOT NULL,
active INTEGER DEFAULT 1 active BOOLEAN DEFAULT 1 NOT NULL
); );
CREATE UNIQUE INDEX IF NOT EXISTS username_index ON users (username); CREATE UNIQUE INDEX username_index ON users (username);
CREATE TABLE IF NOT EXISTS client_store CREATE TABLE client_store
( (
subject TEXT PRIMARY KEY UNIQUE NOT NULL, subject TEXT PRIMARY KEY UNIQUE NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
secret TEXT UNIQUE NOT NULL, secret TEXT UNIQUE NOT NULL,
domain TEXT NOT NULL, domain TEXT NOT NULL,
owner TEXT NOT NULL, owner TEXT NOT NULL,
public INTEGER, public BOOLEAN NOT NULL,
sso INTEGER, sso BOOLEAN NOT NULL,
active INTEGER DEFAULT 1, active BOOLEAN DEFAULT 1 NOT NULL,
FOREIGN KEY (owner) REFERENCES users (subject) FOREIGN KEY (owner) REFERENCES users (subject)
); );
CREATE TABLE IF NOT EXISTS otp CREATE TABLE otp
( (
subject TEXT PRIMARY KEY UNIQUE NOT NULL, subject TEXT PRIMARY KEY UNIQUE NOT NULL,
secret TEXT NOT NULL, secret TEXT NOT NULL,

View File

@ -6,17 +6,21 @@ package database
import ( import (
"database/sql" "database/sql"
"time"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/password"
) )
type ClientStore struct { type ClientStore struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Name string `json:"name"` Name string `json:"name"`
Secret string `json:"secret"` Secret string `json:"secret"`
Domain string `json:"domain"` Domain string `json:"domain"`
Owner string `json:"owner"` Owner string `json:"owner"`
Public sql.NullInt64 `json:"public"` Public bool `json:"public"`
Sso sql.NullInt64 `json:"sso"` Sso bool `json:"sso"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
} }
type Otp struct { type Otp struct {
@ -26,20 +30,20 @@ type Otp struct {
} }
type User struct { type User struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Name string `json:"name"` Name string `json:"name"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password password.HashString `json:"password"`
Picture interface{} `json:"picture"` Picture string `json:"picture"`
Website interface{} `json:"website"` Website string `json:"website"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified int64 `json:"email_verified"` EmailVerified bool `json:"email_verified"`
Pronouns interface{} `json:"pronouns"` Pronouns types.UserPronoun `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"` Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo interface{} `json:"zoneinfo"` Zoneinfo types.UserZone `json:"zoneinfo"`
Locale interface{} `json:"locale"` Locale types.UserLocale `json:"locale"`
Role int64 `json:"role"` Role types.UserRole `json:"role"`
UpdatedAt sql.NullTime `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Registered sql.NullInt64 `json:"registered"` Registered time.Time `json:"registered"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
} }

View File

@ -0,0 +1,47 @@
package database
import (
"context"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/password"
"github.com/google/uuid"
"time"
)
type AddUserParams struct {
Name string `json:"name"`
Username string `json:"username"`
Password string `json:"password"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Role types.UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Active bool `json:"active"`
}
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) {
pwHash, err := password.HashPassword(arg.Password)
if err != nil {
return "", err
}
a := addUserParams{
Subject: uuid.NewString(),
Name: arg.Name,
Username: arg.Username,
Password: pwHash,
Email: arg.Email,
EmailVerified: arg.EmailVerified,
Role: arg.Role,
UpdatedAt: arg.UpdatedAt,
Active: arg.Active,
}
return a.Subject, q.addUser(ctx, a)
}
type CheckLoginRow struct {
Subject string `json:"subject"`
Password password.HashString `json:"password"`
HasTwoFactor bool `json:"hasTwoFactor"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}

View File

@ -1,5 +1,5 @@
-- name: GetClientInfo :one -- name: GetClientInfo :one
SELECT secret, name, domain, public, sso, active SELECT *
FROM client_store FROM client_store
WHERE subject = ? WHERE subject = ?
LIMIT 1; LIMIT 1;

View File

@ -13,22 +13,21 @@ WHERE username = ?
LIMIT 1; LIMIT 1;
-- name: GetUser :one -- name: GetUser :one
SELECT name, SELECT *
username,
picture,
website,
email,
email_verified,
pronouns,
birthdate,
zoneinfo,
locale,
updated_at,
active
FROM users FROM users
WHERE subject = ? WHERE subject = ?
LIMIT 1; LIMIT 1;
-- name: GetUserRole :one
SELECT role
FROM users
WHERE subject = ?;
-- name: GetUserDisplayName :one
SELECT name
FROM users
WHERE subject = ?;
-- name: getUserPassword :one -- name: getUserPassword :one
SELECT password SELECT password
FROM users FROM users
@ -68,3 +67,6 @@ WHERE otp.subject = ?;
SELECT secret, digits SELECT secret, digits
FROM otp FROM otp
WHERE subject = ?; WHERE subject = ?;
-- name: HasTwoFactor :one
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN);

View File

@ -1,302 +0,0 @@
package database
import (
"database/sql"
"fmt"
"github.com/1f349/tulip/password"
"github.com/go-oauth2/oauth2/v4"
"github.com/google/uuid"
"time"
)
func updatedAt() string {
return time.Now().UTC().Format(time.DateTime)
}
type Tx struct{ tx *sql.Tx }
func (t *Tx) Commit() error {
return t.tx.Commit()
}
func (t *Tx) Rollback() {
_ = t.tx.Rollback()
}
func (t *Tx) HasUser() error {
var exists bool
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
err := row.Scan(&exists)
if err != nil {
return err
}
if !exists {
return sql.ErrNoRows
}
return nil
}
func (t *Tx) InsertUser(name, un, pw, email string, verifyEmail bool, role UserRole, active bool) (uuid.UUID, error) {
pwHash, err := password.HashPassword(pw)
if err != nil {
return uuid.UUID{}, err
}
u := uuid.New()
_, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email, email_verified, role, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u, name, un, pwHash, email, verifyEmail, role, updatedAt(), active)
return u, err
}
func (t *Tx) CheckLogin(un, pw string) (*User, bool, bool, error) {
var u User
var pwHash password.HashString
var hasOtp, hasVerify bool
row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject), email, email_verified FROM users WHERE username = ?`, un)
err := row.Scan(&u.Sub, &pwHash, &hasOtp, &u.Email, &hasVerify)
if err != nil {
return nil, false, false, err
}
err = password.CheckPasswordHash(pwHash, pw)
return &u, hasOtp, hasVerify, err
}
func (t *Tx) GetUserDisplayName(sub string) (*User, error) {
var u User
row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&u.Name)
u.Sub = sub
return &u, err
}
func (t *Tx) GetUserRole(sub string) (UserRole, error) {
var r UserRole
row := t.tx.QueryRow(`SELECT role FROM users WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&r)
return r, err
}
func (t *Tx) GetUser(sub string) (*User, error) {
var u User
row := t.tx.QueryRow(`SELECT name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub)
err := row.Scan(&u.Name, &u.Username, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
u.Sub = sub
return &u, err
}
func (t *Tx) GetUserEmail(sub string) (string, error) {
var email string
row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub)
err := row.Scan(&email)
return email, err
}
func (t *Tx) ChangeUserPassword(sub, pwOld, pwNew string) error {
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
if err != nil {
return err
}
var pwHash password.HashString
if q.Next() {
err = q.Scan(&pwHash)
if err != nil {
return err
}
} else {
return fmt.Errorf("invalid user")
}
if err := q.Err(); err != nil {
return err
}
if err := q.Close(); err != nil {
return err
}
err = password.CheckPasswordHash(pwHash, pwOld)
if err != nil {
return err
}
pwNewHash, err := password.HashPassword(pwNew)
if err != nil {
return err
}
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, updatedAt(), sub, pwHash)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) ModifyUser(sub string, v *UserPatch) error {
exec, err := t.tx.Exec(
`UPDATE users
SET name = ?,
picture = ?,
website = ?,
pronouns = ?,
birthdate = ?,
zoneinfo = ?,
locale = ?,
updated_at = ?
WHERE subject = ?`,
v.Name,
v.Picture,
v.Website,
v.Pronouns.String(),
v.Birthdate,
v.ZoneInfo.String(),
v.Locale.String(),
updatedAt(),
sub,
)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) SetTwoFactor(sub string, secret string, digits int) error {
if secret == "" && digits == 0 {
_, err := t.tx.Exec(`DELETE FROM otp WHERE otp.subject = ?`, sub)
return err
}
_, err := t.tx.Exec(`INSERT INTO otp(subject, secret, digits) VALUES (?, ?, ?) ON CONFLICT(subject) DO UPDATE SET secret = excluded.secret, digits = excluded.digits`, sub, secret, digits)
return err
}
func (t *Tx) GetTwoFactor(sub string) (string, int, error) {
var secret string
var digits int
row := t.tx.QueryRow(`SELECT secret, digits FROM otp WHERE subject = ?`, sub)
err := row.Scan(&secret, &digits)
if err != nil {
return "", 0, err
}
return secret, digits, nil
}
func (t *Tx) HasTwoFactor(sub string) (bool, error) {
var hasOtp bool
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM otp WHERE otp.subject = ?)`, sub)
err := row.Scan(&hasOtp)
if err != nil {
return false, err
}
return hasOtp, row.Err()
}
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
var u ClientInfoDbOutput
row := t.tx.QueryRow(`SELECT secret, name, domain, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Public, &u.SSO, &u.Active)
u.Owner = sub
if !u.Active {
return nil, fmt.Errorf("client is not active")
}
return &u, err
}
func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) {
var u []ClientInfoDbOutput
row, err := t.tx.Query(`SELECT subject, name, domain, owner, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset)
if err != nil {
return nil, err
}
defer row.Close()
for row.Next() {
var a ClientInfoDbOutput
err := row.Scan(&a.Sub, &a.Name, &a.Domain, &a.Owner, &a.Public, &a.SSO, &a.Active)
if err != nil {
return nil, err
}
u = append(u, a)
}
return u, row.Err()
}
func (t *Tx) InsertClientApp(name, domain string, public, sso, active bool, owner string) error {
u := uuid.New()
secret, err := password.GenerateApiSecret(70)
if err != nil {
return err
}
_, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, public, sso, active)
return err
}
func (t *Tx) UpdateClientApp(subject, owner string, name, domain string, public, sso, active bool) error {
_, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, public, sso, active, subject, owner)
return err
}
func (t *Tx) ResetClientAppSecret(subject, owner string) (string, error) {
secret, err := password.GenerateApiSecret(70)
if err != nil {
return "", err
}
_, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject, owner)
return secret, err
}
func (t *Tx) GetUserList(offset int) ([]User, error) {
var u []User
row, err := t.tx.Query(`SELECT subject, name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset)
if err != nil {
return nil, err
}
for row.Next() {
var a User
err := row.Scan(&a.Sub, &a.Name, &a.Username, &a.Picture, &a.Website, &a.Email, &a.EmailVerified, &a.Pronouns, &a.Birthdate, &a.ZoneInfo, &a.Locale, &a.Role, &a.UpdatedAt, &a.Active)
if err != nil {
return nil, err
}
u = append(u, a)
}
return u, row.Err()
}
func (t *Tx) UpdateUser(subject string, role UserRole, active bool) error {
_, err := t.tx.Exec(`UPDATE users SET active = ?, role = ? WHERE subject = ?`, active, role, subject)
return err
}
func (t *Tx) VerifyUserEmail(sub string) error {
_, err := t.tx.Exec(`UPDATE users SET email_verified = 1 WHERE subject = ?`, sub)
return err
}
func (t *Tx) UserResetPassword(sub string, pw string) error {
hashPassword, err := password.HashPassword(pw)
if err != nil {
return err
}
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ?`, hashPassword, updatedAt(), sub)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) UserEmailExists(email string) (exists bool, err error) {
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email)
err = row.Scan(&exists)
return
}

View File

@ -1,52 +0,0 @@
package database
import (
"github.com/1f349/tulip/password"
"github.com/MrMelon54/pronouns"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"testing"
"time"
)
func TestTx_ChangeUserPassword(t *testing.T) {
u := uuid.New()
pw, err := password.HashPassword("test")
assert.NoError(t, err)
d, err := Open("file::memory:")
assert.NoError(t, err)
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
assert.NoError(t, err)
tx, err := d.Begin()
assert.NoError(t, err)
err = tx.ChangeUserPassword(u.String(), "test", "new")
assert.NoError(t, err)
assert.NoError(t, tx.Commit())
query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test")
assert.NoError(t, err)
assert.True(t, query.Next())
var oldPw password.HashString
assert.NoError(t, query.Scan(&oldPw))
assert.NoError(t, password.CheckPasswordHash(oldPw, "new"))
assert.NoError(t, query.Err())
assert.NoError(t, query.Close())
}
func TestTx_ModifyUser(t *testing.T) {
u := uuid.New()
pw, err := password.HashPassword("test")
assert.NoError(t, err)
d, err := Open("file::memory:")
assert.NoError(t, err)
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
assert.NoError(t, err)
tx, err := d.Begin()
assert.NoError(t, err)
assert.NoError(t, tx.ModifyUser(u.String(), &UserPatch{
Name: "example",
Pronouns: pronouns.TheyThem,
ZoneInfo: time.UTC,
Locale: language.AmericanEnglish,
}))
}

View File

@ -0,0 +1,38 @@
package types
import (
"database/sql"
"encoding/json"
"fmt"
"golang.org/x/text/language"
)
var (
_ sql.Scanner = &UserLocale{}
_ json.Marshaler = &UserLocale{}
_ json.Unmarshaler = &UserLocale{}
)
type UserLocale struct{ language.Tag }
func (l *UserLocale) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
lang, err := language.Parse(s)
if err != nil {
return err
}
l.Tag = lang
return nil
}
func (l UserLocale) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
func (l *UserLocale) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}

View File

@ -0,0 +1,12 @@
package types
import (
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"testing"
)
func TestUserLocale_MarshalJSON(t *testing.T) {
assert.Equal(t, "\"en-US\"", encode(UserLocale{language.AmericanEnglish}))
assert.Equal(t, "\"en-GB\"", encode(UserLocale{language.BritishEnglish}))
}

View File

@ -0,0 +1,38 @@
package types
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/MrMelon54/pronouns"
)
var (
_ sql.Scanner = &UserPronoun{}
_ json.Marshaler = &UserPronoun{}
_ json.Unmarshaler = &UserPronoun{}
)
type UserPronoun struct{ pronouns.Pronoun }
func (p *UserPronoun) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
}
pro, err := pronouns.FindPronoun(s)
if err != nil {
return err
}
p.Pronoun = pro
return nil
}
func (p UserPronoun) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
func (p *UserPronoun) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return p.Scan(a)
}

View File

@ -0,0 +1,15 @@
package types
import (
"github.com/MrMelon54/pronouns"
"github.com/stretchr/testify/assert"
"testing"
)
func TestUserPronoun_MarshalJSON(t *testing.T) {
assert.Equal(t, "\"they/them\"", encode(UserPronoun{pronouns.TheyThem}))
assert.Equal(t, "\"he/him\"", encode(UserPronoun{pronouns.HeHim}))
assert.Equal(t, "\"she/her\"", encode(UserPronoun{pronouns.SheHer}))
assert.Equal(t, "\"it/its\"", encode(UserPronoun{pronouns.ItIts}))
assert.Equal(t, "\"one/one's\"", encode(UserPronoun{pronouns.OneOnes}))
}

View File

@ -0,0 +1,27 @@
package types
import "fmt"
type UserRole int64
const (
RoleMember UserRole = iota
RoleAdmin
RoleToDelete
)
func (r UserRole) String() string {
switch r {
case RoleMember:
return "Member"
case RoleAdmin:
return "Admin"
case RoleToDelete:
return "ToDelete"
}
return fmt.Sprintf("UserRole{ %d }", r)
}
func (r UserRole) IsValid() bool {
return r == RoleMember || r == RoleAdmin
}

View File

@ -0,0 +1,45 @@
package types
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"time"
)
var (
_ sql.Scanner = &UserZone{}
_ driver.Valuer = &UserZone{}
_ json.Marshaler = &UserZone{}
_ json.Unmarshaler = &UserZone{}
)
type UserZone struct{ *time.Location }
func (l *UserZone) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
loc, err := time.LoadLocation(s)
if err != nil {
return err
}
l.Location = loc
return nil
}
func (l UserZone) Value() (driver.Value, error) {
return l.Location.String(), nil
}
func (l UserZone) MarshalJSON() ([]byte, error) {
return json.Marshal(l.Location.String())
}
func (l *UserZone) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}

View File

@ -0,0 +1,14 @@
package types
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestUserZone_MarshalJSON(t *testing.T) {
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, "\"Europe/London\"", encode(UserZone{location}))
assert.Equal(t, "\"UTC\"", encode(UserZone{time.UTC}))
}

View File

@ -0,0 +1,11 @@
package types
import "encoding/json"
func encode(data any) string {
j, err := json.Marshal(data)
if err != nil {
panic(err)
}
return string(j)
}

View File

@ -8,6 +8,10 @@ package database
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/password"
) )
const deleteTwoFactor = `-- name: DeleteTwoFactor :exec const deleteTwoFactor = `-- name: DeleteTwoFactor :exec
@ -40,44 +44,20 @@ func (q *Queries) GetTwoFactor(ctx context.Context, subject string) (GetTwoFacto
} }
const getUser = `-- name: GetUser :one const getUser = `-- name: GetUser :one
SELECT name, SELECT subject, name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, registered, active
username,
picture,
website,
email,
email_verified,
pronouns,
birthdate,
zoneinfo,
locale,
updated_at,
active
FROM users FROM users
WHERE subject = ? WHERE subject = ?
LIMIT 1 LIMIT 1
` `
type GetUserRow struct { func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
Name string `json:"name"`
Username string `json:"username"`
Picture interface{} `json:"picture"`
Website interface{} `json:"website"`
Email string `json:"email"`
EmailVerified int64 `json:"email_verified"`
Pronouns interface{} `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo interface{} `json:"zoneinfo"`
Locale interface{} `json:"locale"`
UpdatedAt sql.NullTime `json:"updated_at"`
Active sql.NullInt64 `json:"active"`
}
func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, error) {
row := q.db.QueryRowContext(ctx, getUser, subject) row := q.db.QueryRowContext(ctx, getUser, subject)
var i GetUserRow var i User
err := row.Scan( err := row.Scan(
&i.Subject,
&i.Name, &i.Name,
&i.Username, &i.Username,
&i.Password,
&i.Picture, &i.Picture,
&i.Website, &i.Website,
&i.Email, &i.Email,
@ -86,12 +66,51 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, erro
&i.Birthdate, &i.Birthdate,
&i.Zoneinfo, &i.Zoneinfo,
&i.Locale, &i.Locale,
&i.Role,
&i.UpdatedAt, &i.UpdatedAt,
&i.Registered,
&i.Active, &i.Active,
) )
return i, err return i, err
} }
const getUserDisplayName = `-- name: GetUserDisplayName :one
SELECT name
FROM users
WHERE subject = ?
`
func (q *Queries) GetUserDisplayName(ctx context.Context, subject string) (string, error) {
row := q.db.QueryRowContext(ctx, getUserDisplayName, subject)
var name string
err := row.Scan(&name)
return name, err
}
const getUserRole = `-- name: GetUserRole :one
SELECT role
FROM users
WHERE subject = ?
`
func (q *Queries) GetUserRole(ctx context.Context, subject string) (types.UserRole, error) {
row := q.db.QueryRowContext(ctx, getUserRole, subject)
var role types.UserRole
err := row.Scan(&role)
return role, err
}
const hasTwoFactor = `-- name: HasTwoFactor :one
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN)
`
func (q *Queries) HasTwoFactor(ctx context.Context, subject string) (bool, error) {
row := q.db.QueryRowContext(ctx, hasTwoFactor, subject)
var column_1 bool
err := row.Scan(&column_1)
return column_1, err
}
const hasUser = `-- name: HasUser :one const hasUser = `-- name: HasUser :one
SELECT cast(count(subject) AS BOOLEAN) AS hasUser SELECT cast(count(subject) AS BOOLEAN) AS hasUser
FROM users FROM users
@ -118,15 +137,15 @@ WHERE subject = ?
` `
type ModifyUserParams struct { type ModifyUserParams struct {
Name string `json:"name"` Name string `json:"name"`
Picture interface{} `json:"picture"` Picture string `json:"picture"`
Website interface{} `json:"website"` Website string `json:"website"`
Pronouns interface{} `json:"pronouns"` Pronouns types.UserPronoun `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"` Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo interface{} `json:"zoneinfo"` Zoneinfo types.UserZone `json:"zoneinfo"`
Locale interface{} `json:"locale"` Locale types.UserLocale `json:"locale"`
UpdatedAt sql.NullTime `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Subject string `json:"subject"` Subject string `json:"subject"`
} }
func (q *Queries) ModifyUser(ctx context.Context, arg ModifyUserParams) (int64, error) { func (q *Queries) ModifyUser(ctx context.Context, arg ModifyUserParams) (int64, error) {
@ -171,15 +190,15 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
type addUserParams struct { type addUserParams struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Name string `json:"name"` Name string `json:"name"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password password.HashString `json:"password"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified int64 `json:"email_verified"` EmailVerified bool `json:"email_verified"`
Role int64 `json:"role"` Role types.UserRole `json:"role"`
UpdatedAt sql.NullTime `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Active sql.NullInt64 `json:"active"` Active bool `json:"active"`
} }
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
@ -206,10 +225,10 @@ WHERE subject = ?
` `
type changeUserPasswordParams struct { type changeUserPasswordParams struct {
Password string `json:"password"` Password password.HashString `json:"password"`
UpdatedAt sql.NullTime `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Subject string `json:"subject"` Subject string `json:"subject"`
Password_2 string `json:"password_2"` Password_2 password.HashString `json:"password_2"`
} }
func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) { func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) {
@ -233,11 +252,11 @@ LIMIT 1
` `
type checkLoginRow struct { type checkLoginRow struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Password string `json:"password"` Password password.HashString `json:"password"`
Column3 int64 `json:"column_3"` Column3 int64 `json:"column_3"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified int64 `json:"email_verified"` EmailVerified bool `json:"email_verified"`
} }
func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) { func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) {
@ -259,9 +278,9 @@ FROM users
WHERE subject = ? WHERE subject = ?
` `
func (q *Queries) getUserPassword(ctx context.Context, subject string) (string, error) { func (q *Queries) getUserPassword(ctx context.Context, subject string) (password.HashString, error) {
row := q.db.QueryRowContext(ctx, getUserPassword, subject) row := q.db.QueryRowContext(ctx, getUserPassword, subject)
var password string var password password.HashString
err := row.Scan(&password) err := row.Scan(&password)
return password, err return password, err
} }

View File

@ -4,6 +4,7 @@ import (
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth" "github.com/1f349/mjwt/auth"
"github.com/1f349/tulip/database" "github.com/1f349/tulip/database"
"github.com/1f349/tulip/database/types"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
"net/url" "net/url"
@ -30,14 +31,14 @@ func (u UserAuth) IsGuest() bool {
func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle { func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle {
return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
var role database.UserRole var role types.UserRole
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID) role, err = tx.GetUserRole(req.Context(), auth.ID)
return return
}) { }) {
return return
} }
if role != database.RoleAdmin { if role != types.RoleAdmin {
http.Error(rw, "403 Forbidden", http.StatusForbidden) http.Error(rw, "403 Forbidden", http.StatusForbidden)
return return
} }

View File

@ -9,25 +9,13 @@ import (
// DbTx wraps a database transaction with http error messages and a simple action // DbTx wraps a database transaction with http error messages and a simple action
// function. If the action function returns an error the transaction will be // function. If the action function returns an error the transaction will be
// rolled back. If there is no error then the transaction is committed. // rolled back. If there is no error then the transaction is committed.
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool { func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool {
tx, err := h.db.Begin() err := action(h.db)
if err != nil {
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)
return true
}
defer tx.Rollback()
err = action(tx)
if err != nil { if err != nil {
http.Error(rw, "Database error", http.StatusInternalServerError) http.Error(rw, "Database error", http.StatusInternalServerError)
log.Println("Database action error:", err) log.Println("Database action error:", err)
return true return true
} }
err = tx.Commit()
if err != nil {
http.Error(rw, "Database error", http.StatusInternalServerError)
log.Println("Database commit error:", err)
}
return false return false
} }

View File

@ -11,12 +11,12 @@ import (
"time" "time"
) )
func (h *HttpServer) EditGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) { func (h *HttpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
var user *database.User var user database.User
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
var err error var err error
user, err = tx.GetUser(auth.ID) user, err = tx.GetUser(req.Context(), auth.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to read user data: %w", err) return fmt.Errorf("failed to read user data: %w", err)
} }
@ -64,8 +64,19 @@ func (h *HttpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httpr
_, _ = fmt.Fprintln(rw, "</body>\n</html>") _, _ = fmt.Fprintln(rw, "</body>\n</html>")
return return
} }
if h.DbTx(rw, func(tx *database.Tx) error { m := database.ModifyUserParams{
if err := tx.ModifyUser(auth.ID, &patch); err != nil { Name: patch.Name,
Picture: patch.Picture,
Website: patch.Website,
Pronouns: patch.Pronouns,
Birthdate: patch.Birthdate,
Zoneinfo: patch.ZoneInfo,
Locale: patch.Locale,
UpdatedAt: time.Now(),
Subject: auth.ID,
}
if h.DbTx(rw, func(tx *database.Queries) error {
if _, err := tx.ModifyUser(req.Context(), m); err != nil {
return fmt.Errorf("failed to modify user info: %w", err) return fmt.Errorf("failed to modify user info: %w", err)
} }
return nil return nil

View File

@ -3,6 +3,7 @@ package server
import ( import (
"fmt" "fmt"
"github.com/1f349/tulip/database" "github.com/1f349/tulip/database"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/pages" "github.com/1f349/tulip/pages"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
@ -29,18 +30,19 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
return return
} }
var userWithName *database.User var userWithName string
var userRole types.UserRole
var hasTwoFactor bool var hasTwoFactor bool
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
userWithName, err = tx.GetUserDisplayName(auth.ID) userWithName, err = tx.GetUserDisplayName(req.Context(), auth.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user display name: %w", err) return fmt.Errorf("failed to get user display name: %w", err)
} }
hasTwoFactor, err = tx.HasTwoFactor(auth.ID) hasTwoFactor, err = tx.HasTwoFactor(req.Context(), auth.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user two factor state: %w", err) return fmt.Errorf("failed to get user two factor state: %w", err)
} }
userWithName.Role, err = tx.GetUserRole(auth.ID) userRole, err = tx.GetUserRole(req.Context(), auth.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user role: %w", err) return fmt.Errorf("failed to get user role: %w", err)
} }
@ -54,6 +56,6 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
"User": userWithName, "User": userWithName,
"Nonce": lNonce, "Nonce": lNonce,
"OtpEnabled": hasTwoFactor, "OtpEnabled": hasTwoFactor,
"IsAdmin": userWithName.Role, "IsAdmin": userRole == types.RoleAdmin,
}) })
} }

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/tulip/database" "github.com/1f349/tulip/database"
"github.com/go-oauth2/oauth2/v4" "github.com/go-oauth2/oauth2/v4"
@ -9,7 +10,7 @@ import (
"strings" "strings"
) )
func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) { func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) {
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) { srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
scope := ti.GetScope() scope := ti.GetScope()
if containsScope(scope, "openid") { if containsScope(scope, "openid") {
@ -30,18 +31,13 @@ type IdTokenClaims struct{}
func (a IdTokenClaims) Valid() error { return nil } func (a IdTokenClaims) Valid() error { return nil }
func (a IdTokenClaims) Type() string { return "access-token" } func (a IdTokenClaims) Type() string { return "access-token" }
func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) { func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) {
tx, err := us.Begin() user, err := us.GetUser(context.Background(), ti.GetUserID())
if err != nil { if err != nil {
return "", err return "", err
} }
user, err := tx.GetUser(ti.GetUserID())
if err != nil {
return "", err
}
tx.Rollback()
token, err = key.GenerateJwt(user.Sub, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{}) token, err = key.GenerateJwt(user.Subject, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
return return
} }

View File

@ -62,7 +62,7 @@ func (h *HttpServer) LoginPost(rw http.ResponseWriter, req *http.Request, _ http
var loginMismatch byte var loginMismatch byte
var hasOtp bool var hasOtp bool
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw) loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
@ -176,7 +176,7 @@ func (h *HttpServer) LoginResetPasswordPost(rw http.ResponseWriter, req *http.Re
} }
var emailExists bool var emailExists bool
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
emailExists, err = tx.UserEmailExists(email) emailExists, err = tx.UserEmailExists(email)
return err return err
}) { }) {

View File

@ -18,7 +18,7 @@ func (h *HttpServer) MailVerify(rw http.ResponseWriter, _ *http.Request, params
http.Error(rw, "Invalid email verification code", http.StatusBadRequest) http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
return return
} }
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
return tx.VerifyUserEmail(userSub) return tx.VerifyUserEmail(userSub)
}) { }) {
return return
@ -75,7 +75,7 @@ func (h *HttpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request,
h.mailLinkCache.Delete(k) h.mailLinkCache.Delete(k)
// reset password database call // reset password database call
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
return tx.UserResetPassword(userSub, pw) return tx.UserResetPassword(userSub, pw)
}) { }) {
return return
@ -94,12 +94,12 @@ func (h *HttpServer) MailDelete(rw http.ResponseWriter, _ *http.Request, params
return return
} }
var userInfo *database.User var userInfo *database.User
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
userInfo, err = tx.GetUser(userSub) userInfo, err = tx.GetUser(userSub)
if err != nil { if err != nil {
return return
} }
return tx.UpdateUser(userSub, database.RoleToDelete, false) return tx.UpdateUser(userSub, types.RoleToDelete, false)
}) { }) {
return return
} }

View File

@ -22,14 +22,14 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
} }
} }
var role database.UserRole var role types.UserRole
var appList []database.ClientInfoDbOutput var appList []database.ClientStore
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID) role, err = tx.GetUserRole(auth.ID)
if err != nil { if err != nil {
return return
} }
appList, err = tx.GetAppList(auth.ID, role == database.RoleAdmin, offset) appList, err = tx.GetAppList(auth.ID, role == types.RoleAdmin, offset)
return return
}) { }) {
return return
@ -39,7 +39,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
"ServiceName": h.conf.ServiceName, "ServiceName": h.conf.ServiceName,
"Apps": appList, "Apps": appList,
"Offset": offset, "Offset": offset,
"IsAdmin": role == database.RoleAdmin, "IsAdmin": role == types.RoleAdmin,
"NewAppName": q.Get("NewAppName"), "NewAppName": q.Get("NewAppName"),
"NewAppSecret": q.Get("NewAppSecret"), "NewAppSecret": q.Get("NewAppSecret"),
} }
@ -76,14 +76,14 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
active := req.Form.Has("active") active := req.Form.Has("active")
if sso { if sso {
var role database.UserRole var role types.UserRole
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID) role, err = tx.GetUserRole(auth.ID)
return return
}) { }) {
return return
} }
if role != database.RoleAdmin { if role != types.RoleAdmin {
http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest) http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest)
return return
} }
@ -91,13 +91,13 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
switch action { switch action {
case "create": case "create":
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
return tx.InsertClientApp(name, domain, public, sso, active, auth.ID) return tx.InsertClientApp(name, domain, public, sso, active, auth.ID)
}) { }) {
return return
} }
case "edit": case "edit":
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active) return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active)
}) { }) {
return return
@ -105,7 +105,7 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
case "secret": case "secret":
var info oauth2.ClientInfo var info oauth2.ClientInfo
var secret string var secret string
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
sub := req.Form.Get("subject") sub := req.Form.Get("subject")
info, err = tx.GetClientInfo(sub) info, err = tx.GetClientInfo(sub)
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package server
import ( import (
"errors" "errors"
"github.com/1f349/tulip/database" "github.com/1f349/tulip/database"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/pages" "github.com/1f349/tulip/pages"
"github.com/emersion/go-message/mail" "github.com/emersion/go-message/mail"
"github.com/google/uuid" "github.com/google/uuid"
@ -27,9 +28,9 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
} }
} }
var role database.UserRole var role types.UserRole
var userList []database.User var userList []database.User
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID) role, err = tx.GetUserRole(auth.ID)
if err != nil { if err != nil {
return return
@ -39,7 +40,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
}) { }) {
return return
} }
if role != database.RoleAdmin { if role != types.RoleAdmin {
http.Error(rw, "403 Forbidden", http.StatusForbidden) http.Error(rw, "403 Forbidden", http.StatusForbidden)
return return
} }
@ -76,14 +77,14 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
return return
} }
var role database.UserRole var role types.UserRole
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID) role, err = tx.GetUserRole(auth.ID)
return return
}) { }) {
return return
} }
if role != database.RoleAdmin { if role != types.RoleAdmin {
http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest) http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest)
return return
} }
@ -116,7 +117,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
addrDomain := address.Address[n+1:] addrDomain := address.Address[n+1:]
var userSub uuid.UUID var userSub uuid.UUID
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active) userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active)
return err return err
}) { }) {
@ -136,7 +137,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
return return
} }
case "edit": case "edit":
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
sub := req.Form.Get("subject") sub := req.Form.Get("subject")
return tx.UpdateUser(sub, newRole, active) return tx.UpdateUser(sub, newRole, active)
}) { }) {
@ -151,12 +152,12 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound) http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
} }
func parseRoleValue(role string) (database.UserRole, error) { func parseRoleValue(role string) (types.UserRole, error) {
switch role { switch role {
case "member": case "member":
return database.RoleMember, nil return types.RoleMember, nil
case "admin": case "admin":
return database.RoleAdmin, nil return types.RoleAdmin, nil
} }
return 0, errors.New("invalid role value") return 0, errors.New("invalid role value")
} }

View File

@ -79,7 +79,7 @@ func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request
var user *database.User var user *database.User
var hasOtp bool var hasOtp bool
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
user, err = tx.GetUserDisplayName(auth.ID) user, err = tx.GetUserDisplayName(auth.ID)
if err != nil { if err != nil {
return return

View File

@ -47,7 +47,7 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code strin
var hasOtp bool var hasOtp bool
var secret string var secret string
var digits int var digits int
if h.DbTx(rw, func(tx *database.Tx) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
hasOtp, err = tx.HasTwoFactor(sub) hasOtp, err = tx.HasTwoFactor(sub)
if err != nil { if err != nil {
return return
@ -86,7 +86,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
return return
} }
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
return tx.SetTwoFactor(auth.ID, "", 0) return tx.SetTwoFactor(auth.ID, "", 0)
}) { }) {
return return
@ -118,7 +118,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
if secret == "" { if secret == "" {
// get user email // get user email
var email string var email string
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
var err error var err error
email, err = tx.GetUserEmail(auth.ID) email, err = tx.GetUserEmail(auth.ID)
return err return err
@ -167,7 +167,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
return return
} }
if h.DbTx(rw, func(tx *database.Tx) error { if h.DbTx(rw, func(tx *database.Queries) error {
return tx.SetTwoFactor(auth.ID, secret, digits) return tx.SetTwoFactor(auth.ID, secret, digits)
}) { }) {
return return

View File

@ -31,7 +31,7 @@ type HttpServer struct {
r *httprouter.Router r *httprouter.Router
oauthSrv *server.Server oauthSrv *server.Server
oauthMgr *manage.Manager oauthMgr *manage.Manager
db *database.DB db *database.Queries
conf Conf conf Conf
signingKey mjwt.Signer signingKey mjwt.Signer
@ -50,7 +50,7 @@ type mailLinkKey struct {
data string data string
} }
func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server { func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server {
r := httprouter.New() r := httprouter.New()
// remove last slash from baseUrl // remove last slash from baseUrl
@ -191,10 +191,10 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
return return
} }
var userData *database.User var userData database.GetUserRow
if hs.DbTx(rw, func(tx *database.Tx) (err error) { if hs.DbTx(rw, func(tx *database.Queries) (err error) {
userData, err = tx.GetUser(userId) userData, err = tx.GetUser(req.Context(), userId)
return err return err
}) { }) {
return return

View File

@ -9,7 +9,13 @@ sql:
out: "database" out: "database"
emit_json_tags: true emit_json_tags: true
overrides: overrides:
- column: "routes.flags" - column: "users.password"
go_type: "github.com/1f349/violet/target.Flags" go_type: "github.com/1f349/tulip/password.HashString"
- column: "redirects.flags" - column: "users.role"
go_type: "github.com/1f349/violet/target.Flags" go_type: "github.com/1f349/tulip/database/types.UserRole"
- column: "users.pronouns"
go_type: "github.com/1f349/tulip/database/types.UserPronoun"
- column: "users.zoneinfo"
go_type: "github.com/1f349/tulip/database/types.UserZone"
- column: "users.locale"
go_type: "github.com/1f349/tulip/database/types.UserLocale"