From 37570e215770b110bcd667d86a9a3b950e3f0918 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 11 Mar 2024 12:39:52 +0000 Subject: [PATCH] Correct refactored database calls --- client-store/client-store.go | 12 +- cmd/tulip/serve.go | 4 +- database/clientstore-wrapper.go | 82 +++++ database/db-scanner.go | 145 --------- database/db-scanner_test.go | 52 --- database/db-types.go | 127 -------- database/db_test.go | 5 - database/manage-oauth.sql.go | 64 ++-- database/manage-users.sql.go | 28 +- .../migrations/20240309221547_init.down.sql | 4 + .../migrations/20240309221547_init.up.sql | 32 +- database/models.go | 52 +-- database/password-wrapper.go | 47 +++ database/queries/manage-oauth.sql | 2 +- database/queries/users.sql | 26 +- database/tx.go | 302 ------------------ database/tx_test.go | 52 --- database/types/userlocale.go | 38 +++ database/types/userlocale_test.go | 12 + database/types/userpronoun.go | 38 +++ database/types/userpronoun_test.go | 15 + database/types/userrole.go | 27 ++ database/types/userzone.go | 45 +++ database/types/userzone_test.go | 14 + database/types/utils_test.go | 11 + database/users.sql.go | 135 ++++---- server/auth.go | 9 +- server/db.go | 16 +- server/edit.go | 23 +- server/home.go | 14 +- server/id_token.go | 14 +- server/login.go | 4 +- server/mail.go | 8 +- server/manage-apps.go | 22 +- server/manage-users.go | 23 +- server/oauth.go | 2 +- server/otp.go | 8 +- server/server.go | 10 +- sqlc.yaml | 14 +- 39 files changed, 604 insertions(+), 934 deletions(-) create mode 100644 database/clientstore-wrapper.go delete mode 100644 database/db-scanner.go delete mode 100644 database/db-scanner_test.go delete mode 100644 database/db-types.go delete mode 100644 database/db_test.go create mode 100644 database/password-wrapper.go delete mode 100644 database/tx.go delete mode 100644 database/tx_test.go create mode 100644 database/types/userlocale.go create mode 100644 database/types/userlocale_test.go create mode 100644 database/types/userpronoun.go create mode 100644 database/types/userpronoun_test.go create mode 100644 database/types/userrole.go create mode 100644 database/types/userzone.go create mode 100644 database/types/userzone_test.go create mode 100644 database/types/utils_test.go diff --git a/client-store/client-store.go b/client-store/client-store.go index a57e841..06cb48b 100644 --- a/client-store/client-store.go +++ b/client-store/client-store.go @@ -7,20 +7,16 @@ import ( ) type ClientStore struct { - db *database.DB + db *database.Queries } var _ oauth2.ClientStore = &ClientStore{} -func New(db *database.DB) *ClientStore { +func New(db *database.Queries) *ClientStore { return &ClientStore{db: db} } func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { - tx, err := c.db.BeginCtx(ctx) - if err != nil { - return nil, err - } - defer tx.Rollback() - return tx.GetClientInfo(id) + a, err := c.db.GetClientInfo(ctx, id) + return &a, err } diff --git a/cmd/tulip/serve.go b/cmd/tulip/serve.go index bec34a0..bd3eb7d 100644 --- a/cmd/tulip/serve.go +++ b/cmd/tulip/serve.go @@ -118,7 +118,7 @@ func genHmacKey() []byte { return a } -func checkDbHasUser(db *database.DB) error { +func checkDbHasUser(db *database.Queries) error { tx, err := db.Begin() if err != nil { return fmt.Errorf("failed to start transaction: %w", err) @@ -126,7 +126,7 @@ func checkDbHasUser(db *database.DB) error { defer tx.Rollback() if err := tx.HasUser(); err != nil { if errors.Is(err, sql.ErrNoRows) { - _, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, database.RoleAdmin, false) + _, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, types.RoleAdmin, false) if err != nil { return fmt.Errorf("failed to add user: %w", err) } diff --git a/database/clientstore-wrapper.go b/database/clientstore-wrapper.go new file mode 100644 index 0000000..ba9cf84 --- /dev/null +++ b/database/clientstore-wrapper.go @@ -0,0 +1,82 @@ +package database + +import ( + "database/sql" + "fmt" + "github.com/1f349/tulip/database/types" + "github.com/MrMelon54/pronouns" + "github.com/go-oauth2/oauth2/v4" + "golang.org/x/text/language" + "net/url" + "time" +) + +type UserPatch struct { + Name string + Picture string + Website string + Pronouns types.UserPronoun + Birthdate sql.NullTime + ZoneInfo types.UserZone + Locale types.UserLocale +} + +func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) { + var err error + u.Name = v.Get("name") + u.Picture = v.Get("picture") + u.Website = v.Get("website") + if v.Has("reset_pronouns") { + u.Pronouns.Pronoun = pronouns.TheyThem + } else { + u.Pronouns.Pronoun, err = pronouns.FindPronoun(v.Get("pronouns")) + if err != nil { + safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected")) + } + } + if v.Has("reset_birthdate") || v.Get("birthdate") == "" { + u.Birthdate = sql.NullTime{} + } else { + u.Birthdate = sql.NullTime{Valid: true} + u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate")) + if err != nil { + safeErrs = append(safeErrs, fmt.Errorf("invalid time selected")) + } + } + if v.Has("reset_zoneinfo") { + u.ZoneInfo.Location = time.UTC + } else { + u.ZoneInfo.Location, err = time.LoadLocation(v.Get("zoneinfo")) + if err != nil { + safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected")) + } + } + if v.Has("reset_locale") { + u.Locale.Tag = language.AmericanEnglish + } else { + u.Locale.Tag, err = language.Parse(v.Get("locale")) + if err != nil { + safeErrs = append(safeErrs, fmt.Errorf("invalid language selected")) + } + } + return +} + +var _ oauth2.ClientInfo = &ClientStore{} + +func (c *ClientStore) GetID() string { return c.Subject } +func (c *ClientStore) GetSecret() string { return c.Secret } +func (c *ClientStore) GetDomain() string { return c.Domain } +func (c *ClientStore) IsPublic() bool { return c.Public } +func (c *ClientStore) GetUserID() string { return c.Owner } + +// GetName is an extra field for the oauth handler to display the application +// name +func (c *ClientStore) GetName() string { return c.Name } + +// IsSSO is an extra field for the oauth handler to skip the user input stage +// this is for trusted applications to get permissions without asking the user +func (c *ClientStore) IsSSO() bool { return c.Sso } + +// IsActive is an extra field for the app manager to get the active state +func (c *ClientStore) IsActive() bool { return c.Active } diff --git a/database/db-scanner.go b/database/db-scanner.go deleted file mode 100644 index b4848b6..0000000 --- a/database/db-scanner.go +++ /dev/null @@ -1,145 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "fmt" - "github.com/MrMelon54/pronouns" - "golang.org/x/text/language" - "time" -) - -var ( - _, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{} - _, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{} - _, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{} -) - -func marshalValueOrNull(null bool, data any) ([]byte, error) { - if null { - return json.Marshal(nil) - } - return json.Marshal(data) -} - -type NullStringScanner struct{ sql.NullString } - -func (s *NullStringScanner) Null() bool { return !s.Valid } -func (s *NullStringScanner) Scan(src any) error { return s.NullString.Scan(src) } -func (s NullStringScanner) MarshalJSON() ([]byte, error) { - return marshalValueOrNull(s.Null(), s.NullString.String) -} -func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error { - if string(bytes) == "null" { - return s.Scan(nil) - } - var a string - err := json.Unmarshal(bytes, &a) - if err != nil { - return err - } - return s.Scan(&a) -} -func (s NullStringScanner) String() string { - if s.Null() { - return "" - } - return s.NullString.String -} - -type NullDateScanner struct{ sql.NullTime } - -func (t *NullDateScanner) Null() bool { return !t.Valid } -func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) } -func (t NullDateScanner) MarshalJSON() ([]byte, error) { - return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly)) -} -func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error { - if string(bytes) == "null" { - return t.Scan(nil) - } - var a string - err := json.Unmarshal(bytes, &a) - if err != nil { - return err - } - return t.Scan(&a) -} -func (t NullDateScanner) String() string { - if t.Null() { - return "" - } - return t.NullTime.Time.UTC().Format(time.DateOnly) -} - -type LocationScanner struct{ *time.Location } - -func (l *LocationScanner) Scan(src any) error { - s, ok := src.(string) - if !ok { - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l) - } - loc, err := time.LoadLocation(s) - if err != nil { - return err - } - l.Location = loc - return nil -} -func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) } -func (l *LocationScanner) UnmarshalJSON(bytes []byte) error { - var a string - err := json.Unmarshal(bytes, &a) - if err != nil { - return err - } - return l.Scan(a) -} - -type LocaleScanner struct{ language.Tag } - -func (l *LocaleScanner) Scan(src any) error { - s, ok := src.(string) - if !ok { - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l) - } - lang, err := language.Parse(s) - if err != nil { - return err - } - l.Tag = lang - return nil -} -func (l LocaleScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) } -func (l *LocaleScanner) UnmarshalJSON(bytes []byte) error { - var a string - err := json.Unmarshal(bytes, &a) - if err != nil { - return err - } - return l.Scan(a) -} - -type PronounScanner struct{ pronouns.Pronoun } - -func (p *PronounScanner) Scan(src any) error { - s, ok := src.(string) - if !ok { - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p) - } - pro, err := pronouns.FindPronoun(s) - if err != nil { - return err - } - p.Pronoun = pro - return nil -} -func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) } -func (p *PronounScanner) UnmarshalJSON(bytes []byte) error { - var a string - err := json.Unmarshal(bytes, &a) - if err != nil { - return err - } - return p.Scan(a) -} diff --git a/database/db-scanner_test.go b/database/db-scanner_test.go deleted file mode 100644 index c7217f3..0000000 --- a/database/db-scanner_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "github.com/MrMelon54/pronouns" - "github.com/stretchr/testify/assert" - "golang.org/x/text/language" - "testing" - "time" -) - -func encode(data any) string { - j, err := json.Marshal(map[string]any{"value": data}) - if err != nil { - panic(err) - } - return string(j) -} - -func TestStringScanner_MarshalJSON(t *testing.T) { - assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}})) - assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}})) -} - -func TestDateScanner_MarshalJSON(t *testing.T) { - location, err := time.LoadLocation("Europe/London") - assert.NoError(t, err) - assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}})) - assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}})) - assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{})) -} - -func TestLocationScanner_MarshalJSON(t *testing.T) { - location, err := time.LoadLocation("Europe/London") - assert.NoError(t, err) - assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location})) - assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC})) -} - -func TestLocaleScanner_MarshalJSON(t *testing.T) { - assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish})) - assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish})) -} - -func TestPronounScanner_MarshalJSON(t *testing.T) { - assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem})) - assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim})) - assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer})) - assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts})) - assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes})) -} diff --git a/database/db-types.go b/database/db-types.go deleted file mode 100644 index d368445..0000000 --- a/database/db-types.go +++ /dev/null @@ -1,127 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "github.com/MrMelon54/pronouns" - "github.com/go-oauth2/oauth2/v4" - "golang.org/x/text/language" - "net/url" - "time" -) - -type User struct { - Sub string `json:"sub"` - Name string `json:"name,omitempty"` - Username string `json:"username"` - Picture NullStringScanner `json:"picture,omitempty"` - Website NullStringScanner `json:"website,omitempty"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - Pronouns PronounScanner `json:"pronouns,omitempty"` - Birthdate NullDateScanner `json:"birthdate,omitempty"` - ZoneInfo LocationScanner `json:"zoneinfo,omitempty"` - Locale LocaleScanner `json:"locale,omitempty"` - Role UserRole `json:"role"` - UpdatedAt time.Time `json:"updated_at"` - Active bool `json:"active"` -} - -type UserRole int - -const ( - RoleMember UserRole = iota - RoleAdmin - RoleToDelete -) - -func (r UserRole) String() string { - switch r { - case RoleMember: - return "Member" - case RoleAdmin: - return "Admin" - case RoleToDelete: - return "ToDelete" - } - return fmt.Sprintf("UserRole{ %d }", r) -} - -func (r UserRole) IsValid() bool { - return r == RoleMember || r == RoleAdmin -} - -type UserPatch struct { - Name string - Picture string - Website string - Pronouns pronouns.Pronoun - Birthdate sql.NullTime - ZoneInfo *time.Location - Locale language.Tag -} - -func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) { - var err error - u.Name = v.Get("name") - u.Picture = v.Get("picture") - u.Website = v.Get("website") - if v.Has("reset_pronouns") { - u.Pronouns = pronouns.TheyThem - } else { - u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns")) - if err != nil { - safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected")) - } - } - if v.Has("reset_birthdate") || v.Get("birthdate") == "" { - u.Birthdate = sql.NullTime{} - } else { - u.Birthdate = sql.NullTime{Valid: true} - u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate")) - if err != nil { - safeErrs = append(safeErrs, fmt.Errorf("invalid time selected")) - } - } - if v.Has("reset_zoneinfo") { - u.ZoneInfo = time.UTC - } else { - u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo")) - if err != nil { - safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected")) - } - } - if v.Has("reset_locale") { - u.Locale = language.AmericanEnglish - } else { - u.Locale, err = language.Parse(v.Get("locale")) - if err != nil { - safeErrs = append(safeErrs, fmt.Errorf("invalid language selected")) - } - } - return -} - -type ClientInfoDbOutput struct { - Sub, Name, Secret, Domain, Owner string - Public, SSO, Active bool -} - -var _ oauth2.ClientInfo = &ClientInfoDbOutput{} - -func (c *ClientInfoDbOutput) GetID() string { return c.Sub } -func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret } -func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain } -func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public } -func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner } - -// GetName is an extra field for the oauth handler to display the application -// name -func (c *ClientInfoDbOutput) GetName() string { return c.Name } - -// IsSSO is an extra field for the oauth handler to skip the user input stage -// this is for trusted applications to get permissions without asking the user -func (c *ClientInfoDbOutput) IsSSO() bool { return c.SSO } - -// IsActive is an extra field for the app manager to get the active state -func (c *ClientInfoDbOutput) IsActive() bool { return c.Active } diff --git a/database/db_test.go b/database/db_test.go deleted file mode 100644 index 5d73aab..0000000 --- a/database/db_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package database - -import ( - _ "github.com/mattn/go-sqlite3" -) diff --git a/database/manage-oauth.sql.go b/database/manage-oauth.sql.go index 4f60ecf..7626567 100644 --- a/database/manage-oauth.sql.go +++ b/database/manage-oauth.sql.go @@ -7,7 +7,6 @@ package database import ( "context" - "database/sql" ) const getAppList = `-- name: GetAppList :many @@ -25,13 +24,13 @@ type GetAppListParams struct { } type GetAppListRow struct { - Subject string `json:"subject"` - Name string `json:"name"` - Domain string `json:"domain"` - Owner string `json:"owner"` - Public sql.NullInt64 `json:"public"` - Sso sql.NullInt64 `json:"sso"` - Active sql.NullInt64 `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Domain string `json:"domain"` + Owner string `json:"owner"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` } func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) { @@ -66,28 +65,21 @@ func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAp } const getClientInfo = `-- name: GetClientInfo :one -SELECT secret, name, domain, public, sso, active +SELECT subject, name, secret, domain, owner, public, sso, active FROM client_store WHERE subject = ? LIMIT 1 ` -type GetClientInfoRow struct { - Secret string `json:"secret"` - Name string `json:"name"` - Domain string `json:"domain"` - Public sql.NullInt64 `json:"public"` - Sso sql.NullInt64 `json:"sso"` - Active sql.NullInt64 `json:"active"` -} - -func (q *Queries) GetClientInfo(ctx context.Context, subject string) (GetClientInfoRow, error) { +func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) { row := q.db.QueryRowContext(ctx, getClientInfo, subject) - var i GetClientInfoRow + var i ClientStore err := row.Scan( - &i.Secret, + &i.Subject, &i.Name, + &i.Secret, &i.Domain, + &i.Owner, &i.Public, &i.Sso, &i.Active, @@ -101,14 +93,14 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?) ` type InsertClientAppParams struct { - Subject string `json:"subject"` - Name string `json:"name"` - Secret string `json:"secret"` - Domain string `json:"domain"` - Owner string `json:"owner"` - Public sql.NullInt64 `json:"public"` - Sso sql.NullInt64 `json:"sso"` - Active sql.NullInt64 `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Secret string `json:"secret"` + Domain string `json:"domain"` + Owner string `json:"owner"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` } func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error { @@ -137,13 +129,13 @@ WHERE subject = ? ` type UpdateClientAppParams struct { - Name string `json:"name"` - Domain string `json:"domain"` - Public sql.NullInt64 `json:"public"` - Sso sql.NullInt64 `json:"sso"` - Active sql.NullInt64 `json:"active"` - Subject string `json:"subject"` - Owner string `json:"owner"` + Name string `json:"name"` + Domain string `json:"domain"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` + Subject string `json:"subject"` + Owner string `json:"owner"` } func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error { diff --git a/database/manage-users.sql.go b/database/manage-users.sql.go index 57e2241..910218a 100644 --- a/database/manage-users.sql.go +++ b/database/manage-users.sql.go @@ -7,7 +7,9 @@ package database import ( "context" - "database/sql" + "time" + + "github.com/1f349/tulip/database/types" ) const getUserList = `-- name: GetUserList :many @@ -25,15 +27,15 @@ LIMIT 25 OFFSET ? ` type GetUserListRow struct { - Subject string `json:"subject"` - Name string `json:"name"` - Username string `json:"username"` - Picture interface{} `json:"picture"` - Email string `json:"email"` - EmailVerified int64 `json:"email_verified"` - Role int64 `json:"role"` - UpdatedAt sql.NullTime `json:"updated_at"` - Active sql.NullInt64 `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Username string `json:"username"` + Picture string `json:"picture"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Role types.UserRole `json:"role"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` } func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) { @@ -77,9 +79,9 @@ WHERE subject = ? ` type UpdateUserRoleParams struct { - Active sql.NullInt64 `json:"active"` - Role int64 `json:"role"` - Subject string `json:"subject"` + Active bool `json:"active"` + Role types.UserRole `json:"role"` + Subject string `json:"subject"` } func (q *Queries) UpdateUserRole(ctx context.Context, arg UpdateUserRoleParams) error { diff --git a/database/migrations/20240309221547_init.down.sql b/database/migrations/20240309221547_init.down.sql index e69de29..617d782 100644 --- a/database/migrations/20240309221547_init.down.sql +++ b/database/migrations/20240309221547_init.down.sql @@ -0,0 +1,4 @@ +DROP TABLE users; +DROP INDEX username_index; +DROP TABLE client_store; +DROP TABLE otp; diff --git a/database/migrations/20240309221547_init.up.sql b/database/migrations/20240309221547_init.up.sql index 691aeda..85e4b52 100644 --- a/database/migrations/20240309221547_init.up.sql +++ b/database/migrations/20240309221547_init.up.sql @@ -1,39 +1,39 @@ -CREATE TABLE IF NOT EXISTS users +CREATE TABLE users ( subject TEXT PRIMARY KEY UNIQUE NOT NULL, name TEXT NOT NULL, username TEXT UNIQUE NOT NULL, password TEXT NOT NULL, - picture TEXT DEFAULT "" NOT NULL, - website TEXT DEFAULT "" NOT NULL, + picture TEXT DEFAULT '' NOT NULL, + website TEXT DEFAULT '' NOT NULL, email TEXT NOT NULL, - email_verified INTEGER DEFAULT 0 NOT NULL, - pronouns TEXT DEFAULT "they/them" NOT NULL, + email_verified BOOLEAN DEFAULT 0 NOT NULL, + pronouns TEXT DEFAULT 'they/them' NOT NULL, birthdate DATE, - zoneinfo TEXT DEFAULT "UTC" NOT NULL, - locale TEXT DEFAULT "en-US" NOT NULL, + zoneinfo TEXT DEFAULT 'UTC' NOT NULL, + locale TEXT DEFAULT 'en-US' NOT NULL, role INTEGER DEFAULT 0 NOT NULL, - updated_at DATETIME, - registered INTEGER DEFAULT 0, - active INTEGER DEFAULT 1 + updated_at DATETIME NOT NULL, + registered DATETIME NOT NULL, + active BOOLEAN DEFAULT 1 NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS username_index ON users (username); +CREATE UNIQUE INDEX username_index ON users (username); -CREATE TABLE IF NOT EXISTS client_store +CREATE TABLE client_store ( subject TEXT PRIMARY KEY UNIQUE NOT NULL, name TEXT NOT NULL, secret TEXT UNIQUE NOT NULL, domain TEXT NOT NULL, owner TEXT NOT NULL, - public INTEGER, - sso INTEGER, - active INTEGER DEFAULT 1, + public BOOLEAN NOT NULL, + sso BOOLEAN NOT NULL, + active BOOLEAN DEFAULT 1 NOT NULL, FOREIGN KEY (owner) REFERENCES users (subject) ); -CREATE TABLE IF NOT EXISTS otp +CREATE TABLE otp ( subject TEXT PRIMARY KEY UNIQUE NOT NULL, secret TEXT NOT NULL, diff --git a/database/models.go b/database/models.go index a00dec9..f0b0816 100644 --- a/database/models.go +++ b/database/models.go @@ -6,17 +6,21 @@ package database import ( "database/sql" + "time" + + "github.com/1f349/tulip/database/types" + "github.com/1f349/tulip/password" ) type ClientStore struct { - Subject string `json:"subject"` - Name string `json:"name"` - Secret string `json:"secret"` - Domain string `json:"domain"` - Owner string `json:"owner"` - Public sql.NullInt64 `json:"public"` - Sso sql.NullInt64 `json:"sso"` - Active sql.NullInt64 `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Secret string `json:"secret"` + Domain string `json:"domain"` + Owner string `json:"owner"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` } type Otp struct { @@ -26,20 +30,20 @@ type Otp struct { } type User struct { - Subject string `json:"subject"` - Name string `json:"name"` - Username string `json:"username"` - Password string `json:"password"` - Picture interface{} `json:"picture"` - Website interface{} `json:"website"` - Email string `json:"email"` - EmailVerified int64 `json:"email_verified"` - Pronouns interface{} `json:"pronouns"` - Birthdate sql.NullTime `json:"birthdate"` - Zoneinfo interface{} `json:"zoneinfo"` - Locale interface{} `json:"locale"` - Role int64 `json:"role"` - UpdatedAt sql.NullTime `json:"updated_at"` - Registered sql.NullInt64 `json:"registered"` - Active sql.NullInt64 `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Username string `json:"username"` + Password password.HashString `json:"password"` + Picture string `json:"picture"` + Website string `json:"website"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate sql.NullTime `json:"birthdate"` + Zoneinfo types.UserZone `json:"zoneinfo"` + Locale types.UserLocale `json:"locale"` + Role types.UserRole `json:"role"` + UpdatedAt time.Time `json:"updated_at"` + Registered time.Time `json:"registered"` + Active bool `json:"active"` } diff --git a/database/password-wrapper.go b/database/password-wrapper.go new file mode 100644 index 0000000..765502c --- /dev/null +++ b/database/password-wrapper.go @@ -0,0 +1,47 @@ +package database + +import ( + "context" + "github.com/1f349/tulip/database/types" + "github.com/1f349/tulip/password" + "github.com/google/uuid" + "time" +) + +type AddUserParams struct { + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Role types.UserRole `json:"role"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` +} + +func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) { + pwHash, err := password.HashPassword(arg.Password) + if err != nil { + return "", err + } + a := addUserParams{ + Subject: uuid.NewString(), + Name: arg.Name, + Username: arg.Username, + Password: pwHash, + Email: arg.Email, + EmailVerified: arg.EmailVerified, + Role: arg.Role, + UpdatedAt: arg.UpdatedAt, + Active: arg.Active, + } + return a.Subject, q.addUser(ctx, a) +} + +type CheckLoginRow struct { + Subject string `json:"subject"` + Password password.HashString `json:"password"` + HasTwoFactor bool `json:"hasTwoFactor"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` +} diff --git a/database/queries/manage-oauth.sql b/database/queries/manage-oauth.sql index 60373fd..100054c 100644 --- a/database/queries/manage-oauth.sql +++ b/database/queries/manage-oauth.sql @@ -1,5 +1,5 @@ -- name: GetClientInfo :one -SELECT secret, name, domain, public, sso, active +SELECT * FROM client_store WHERE subject = ? LIMIT 1; diff --git a/database/queries/users.sql b/database/queries/users.sql index 25790c8..1517817 100644 --- a/database/queries/users.sql +++ b/database/queries/users.sql @@ -13,22 +13,21 @@ WHERE username = ? LIMIT 1; -- name: GetUser :one -SELECT name, - username, - picture, - website, - email, - email_verified, - pronouns, - birthdate, - zoneinfo, - locale, - updated_at, - active +SELECT * FROM users WHERE subject = ? LIMIT 1; +-- name: GetUserRole :one +SELECT role +FROM users +WHERE subject = ?; + +-- name: GetUserDisplayName :one +SELECT name +FROM users +WHERE subject = ?; + -- name: getUserPassword :one SELECT password FROM users @@ -68,3 +67,6 @@ WHERE otp.subject = ?; SELECT secret, digits FROM otp WHERE subject = ?; + +-- name: HasTwoFactor :one +SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN); diff --git a/database/tx.go b/database/tx.go deleted file mode 100644 index fec526e..0000000 --- a/database/tx.go +++ /dev/null @@ -1,302 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "github.com/1f349/tulip/password" - "github.com/go-oauth2/oauth2/v4" - "github.com/google/uuid" - "time" -) - -func updatedAt() string { - return time.Now().UTC().Format(time.DateTime) -} - -type Tx struct{ tx *sql.Tx } - -func (t *Tx) Commit() error { - return t.tx.Commit() -} - -func (t *Tx) Rollback() { - _ = t.tx.Rollback() -} - -func (t *Tx) HasUser() error { - var exists bool - row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`) - err := row.Scan(&exists) - if err != nil { - return err - } - if !exists { - return sql.ErrNoRows - } - return nil -} - -func (t *Tx) InsertUser(name, un, pw, email string, verifyEmail bool, role UserRole, active bool) (uuid.UUID, error) { - pwHash, err := password.HashPassword(pw) - if err != nil { - return uuid.UUID{}, err - } - u := uuid.New() - _, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email, email_verified, role, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u, name, un, pwHash, email, verifyEmail, role, updatedAt(), active) - return u, err -} - -func (t *Tx) CheckLogin(un, pw string) (*User, bool, bool, error) { - var u User - var pwHash password.HashString - var hasOtp, hasVerify bool - row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject), email, email_verified FROM users WHERE username = ?`, un) - err := row.Scan(&u.Sub, &pwHash, &hasOtp, &u.Email, &hasVerify) - if err != nil { - return nil, false, false, err - } - err = password.CheckPasswordHash(pwHash, pw) - return &u, hasOtp, hasVerify, err -} - -func (t *Tx) GetUserDisplayName(sub string) (*User, error) { - var u User - row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub) - err := row.Scan(&u.Name) - u.Sub = sub - return &u, err -} - -func (t *Tx) GetUserRole(sub string) (UserRole, error) { - var r UserRole - row := t.tx.QueryRow(`SELECT role FROM users WHERE subject = ? LIMIT 1`, sub) - err := row.Scan(&r) - return r, err -} - -func (t *Tx) GetUser(sub string) (*User, error) { - var u User - row := t.tx.QueryRow(`SELECT name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub) - err := row.Scan(&u.Name, &u.Username, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active) - u.Sub = sub - return &u, err -} - -func (t *Tx) GetUserEmail(sub string) (string, error) { - var email string - row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub) - err := row.Scan(&email) - return email, err -} - -func (t *Tx) ChangeUserPassword(sub, pwOld, pwNew string) error { - q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub) - if err != nil { - return err - } - var pwHash password.HashString - if q.Next() { - err = q.Scan(&pwHash) - if err != nil { - return err - } - } else { - return fmt.Errorf("invalid user") - } - if err := q.Err(); err != nil { - return err - } - if err := q.Close(); err != nil { - return err - } - err = password.CheckPasswordHash(pwHash, pwOld) - if err != nil { - return err - } - pwNewHash, err := password.HashPassword(pwNew) - if err != nil { - return err - } - exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, updatedAt(), sub, pwHash) - if err != nil { - return err - } - affected, err := exec.RowsAffected() - if err != nil { - return err - } - if affected != 1 { - return fmt.Errorf("row wasn't updated") - } - return nil -} - -func (t *Tx) ModifyUser(sub string, v *UserPatch) error { - exec, err := t.tx.Exec( - `UPDATE users -SET name = ?, - picture = ?, - website = ?, - pronouns = ?, - birthdate = ?, - zoneinfo = ?, - locale = ?, - updated_at = ? -WHERE subject = ?`, - v.Name, - v.Picture, - v.Website, - v.Pronouns.String(), - v.Birthdate, - v.ZoneInfo.String(), - v.Locale.String(), - updatedAt(), - sub, - ) - if err != nil { - return err - } - affected, err := exec.RowsAffected() - if err != nil { - return err - } - if affected != 1 { - return fmt.Errorf("row wasn't updated") - } - return nil -} - -func (t *Tx) SetTwoFactor(sub string, secret string, digits int) error { - if secret == "" && digits == 0 { - _, err := t.tx.Exec(`DELETE FROM otp WHERE otp.subject = ?`, sub) - return err - } - _, err := t.tx.Exec(`INSERT INTO otp(subject, secret, digits) VALUES (?, ?, ?) ON CONFLICT(subject) DO UPDATE SET secret = excluded.secret, digits = excluded.digits`, sub, secret, digits) - return err -} - -func (t *Tx) GetTwoFactor(sub string) (string, int, error) { - var secret string - var digits int - row := t.tx.QueryRow(`SELECT secret, digits FROM otp WHERE subject = ?`, sub) - err := row.Scan(&secret, &digits) - if err != nil { - return "", 0, err - } - return secret, digits, nil -} - -func (t *Tx) HasTwoFactor(sub string) (bool, error) { - var hasOtp bool - row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM otp WHERE otp.subject = ?)`, sub) - err := row.Scan(&hasOtp) - if err != nil { - return false, err - } - return hasOtp, row.Err() -} - -func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) { - var u ClientInfoDbOutput - row := t.tx.QueryRow(`SELECT secret, name, domain, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub) - err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Public, &u.SSO, &u.Active) - u.Owner = sub - if !u.Active { - return nil, fmt.Errorf("client is not active") - } - return &u, err -} - -func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) { - var u []ClientInfoDbOutput - row, err := t.tx.Query(`SELECT subject, name, domain, owner, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset) - if err != nil { - return nil, err - } - defer row.Close() - for row.Next() { - var a ClientInfoDbOutput - err := row.Scan(&a.Sub, &a.Name, &a.Domain, &a.Owner, &a.Public, &a.SSO, &a.Active) - if err != nil { - return nil, err - } - u = append(u, a) - } - return u, row.Err() -} - -func (t *Tx) InsertClientApp(name, domain string, public, sso, active bool, owner string) error { - u := uuid.New() - secret, err := password.GenerateApiSecret(70) - if err != nil { - return err - } - _, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, public, sso, active) - return err -} - -func (t *Tx) UpdateClientApp(subject, owner string, name, domain string, public, sso, active bool) error { - _, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, public, sso, active, subject, owner) - return err -} - -func (t *Tx) ResetClientAppSecret(subject, owner string) (string, error) { - secret, err := password.GenerateApiSecret(70) - if err != nil { - return "", err - } - _, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject, owner) - return secret, err -} - -func (t *Tx) GetUserList(offset int) ([]User, error) { - var u []User - row, err := t.tx.Query(`SELECT subject, name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset) - if err != nil { - return nil, err - } - for row.Next() { - var a User - err := row.Scan(&a.Sub, &a.Name, &a.Username, &a.Picture, &a.Website, &a.Email, &a.EmailVerified, &a.Pronouns, &a.Birthdate, &a.ZoneInfo, &a.Locale, &a.Role, &a.UpdatedAt, &a.Active) - if err != nil { - return nil, err - } - u = append(u, a) - } - return u, row.Err() -} - -func (t *Tx) UpdateUser(subject string, role UserRole, active bool) error { - _, err := t.tx.Exec(`UPDATE users SET active = ?, role = ? WHERE subject = ?`, active, role, subject) - return err -} - -func (t *Tx) VerifyUserEmail(sub string) error { - _, err := t.tx.Exec(`UPDATE users SET email_verified = 1 WHERE subject = ?`, sub) - return err -} - -func (t *Tx) UserResetPassword(sub string, pw string) error { - hashPassword, err := password.HashPassword(pw) - if err != nil { - return err - } - exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ?`, hashPassword, updatedAt(), sub) - if err != nil { - return err - } - affected, err := exec.RowsAffected() - if err != nil { - return err - } - if affected != 1 { - return fmt.Errorf("row wasn't updated") - } - return nil -} - -func (t *Tx) UserEmailExists(email string) (exists bool, err error) { - row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email) - err = row.Scan(&exists) - return -} diff --git a/database/tx_test.go b/database/tx_test.go deleted file mode 100644 index 7e90224..0000000 --- a/database/tx_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package database - -import ( - "github.com/1f349/tulip/password" - "github.com/MrMelon54/pronouns" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "golang.org/x/text/language" - "testing" - "time" -) - -func TestTx_ChangeUserPassword(t *testing.T) { - u := uuid.New() - pw, err := password.HashPassword("test") - assert.NoError(t, err) - d, err := Open("file::memory:") - assert.NoError(t, err) - _, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt()) - assert.NoError(t, err) - tx, err := d.Begin() - assert.NoError(t, err) - err = tx.ChangeUserPassword(u.String(), "test", "new") - assert.NoError(t, err) - assert.NoError(t, tx.Commit()) - query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test") - assert.NoError(t, err) - assert.True(t, query.Next()) - var oldPw password.HashString - assert.NoError(t, query.Scan(&oldPw)) - assert.NoError(t, password.CheckPasswordHash(oldPw, "new")) - assert.NoError(t, query.Err()) - assert.NoError(t, query.Close()) -} - -func TestTx_ModifyUser(t *testing.T) { - u := uuid.New() - pw, err := password.HashPassword("test") - assert.NoError(t, err) - d, err := Open("file::memory:") - assert.NoError(t, err) - _, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt()) - assert.NoError(t, err) - tx, err := d.Begin() - assert.NoError(t, err) - assert.NoError(t, tx.ModifyUser(u.String(), &UserPatch{ - Name: "example", - Pronouns: pronouns.TheyThem, - ZoneInfo: time.UTC, - Locale: language.AmericanEnglish, - })) -} diff --git a/database/types/userlocale.go b/database/types/userlocale.go new file mode 100644 index 0000000..aea629d --- /dev/null +++ b/database/types/userlocale.go @@ -0,0 +1,38 @@ +package types + +import ( + "database/sql" + "encoding/json" + "fmt" + "golang.org/x/text/language" +) + +var ( + _ sql.Scanner = &UserLocale{} + _ json.Marshaler = &UserLocale{} + _ json.Unmarshaler = &UserLocale{} +) + +type UserLocale struct{ language.Tag } + +func (l *UserLocale) Scan(src any) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l) + } + lang, err := language.Parse(s) + if err != nil { + return err + } + l.Tag = lang + return nil +} +func (l UserLocale) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) } +func (l *UserLocale) UnmarshalJSON(bytes []byte) error { + var a string + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + return l.Scan(a) +} diff --git a/database/types/userlocale_test.go b/database/types/userlocale_test.go new file mode 100644 index 0000000..cc53f80 --- /dev/null +++ b/database/types/userlocale_test.go @@ -0,0 +1,12 @@ +package types + +import ( + "github.com/stretchr/testify/assert" + "golang.org/x/text/language" + "testing" +) + +func TestUserLocale_MarshalJSON(t *testing.T) { + assert.Equal(t, "\"en-US\"", encode(UserLocale{language.AmericanEnglish})) + assert.Equal(t, "\"en-GB\"", encode(UserLocale{language.BritishEnglish})) +} diff --git a/database/types/userpronoun.go b/database/types/userpronoun.go new file mode 100644 index 0000000..49793f9 --- /dev/null +++ b/database/types/userpronoun.go @@ -0,0 +1,38 @@ +package types + +import ( + "database/sql" + "encoding/json" + "fmt" + "github.com/MrMelon54/pronouns" +) + +var ( + _ sql.Scanner = &UserPronoun{} + _ json.Marshaler = &UserPronoun{} + _ json.Unmarshaler = &UserPronoun{} +) + +type UserPronoun struct{ pronouns.Pronoun } + +func (p *UserPronoun) Scan(src any) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p) + } + pro, err := pronouns.FindPronoun(s) + if err != nil { + return err + } + p.Pronoun = pro + return nil +} +func (p UserPronoun) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) } +func (p *UserPronoun) UnmarshalJSON(bytes []byte) error { + var a string + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + return p.Scan(a) +} diff --git a/database/types/userpronoun_test.go b/database/types/userpronoun_test.go new file mode 100644 index 0000000..e42363e --- /dev/null +++ b/database/types/userpronoun_test.go @@ -0,0 +1,15 @@ +package types + +import ( + "github.com/MrMelon54/pronouns" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUserPronoun_MarshalJSON(t *testing.T) { + assert.Equal(t, "\"they/them\"", encode(UserPronoun{pronouns.TheyThem})) + assert.Equal(t, "\"he/him\"", encode(UserPronoun{pronouns.HeHim})) + assert.Equal(t, "\"she/her\"", encode(UserPronoun{pronouns.SheHer})) + assert.Equal(t, "\"it/its\"", encode(UserPronoun{pronouns.ItIts})) + assert.Equal(t, "\"one/one's\"", encode(UserPronoun{pronouns.OneOnes})) +} diff --git a/database/types/userrole.go b/database/types/userrole.go new file mode 100644 index 0000000..85fadfd --- /dev/null +++ b/database/types/userrole.go @@ -0,0 +1,27 @@ +package types + +import "fmt" + +type UserRole int64 + +const ( + RoleMember UserRole = iota + RoleAdmin + RoleToDelete +) + +func (r UserRole) String() string { + switch r { + case RoleMember: + return "Member" + case RoleAdmin: + return "Admin" + case RoleToDelete: + return "ToDelete" + } + return fmt.Sprintf("UserRole{ %d }", r) +} + +func (r UserRole) IsValid() bool { + return r == RoleMember || r == RoleAdmin +} diff --git a/database/types/userzone.go b/database/types/userzone.go new file mode 100644 index 0000000..5ef952f --- /dev/null +++ b/database/types/userzone.go @@ -0,0 +1,45 @@ +package types + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "time" +) + +var ( + _ sql.Scanner = &UserZone{} + _ driver.Valuer = &UserZone{} + _ json.Marshaler = &UserZone{} + _ json.Unmarshaler = &UserZone{} +) + +type UserZone struct{ *time.Location } + +func (l *UserZone) Scan(src any) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l) + } + loc, err := time.LoadLocation(s) + if err != nil { + return err + } + l.Location = loc + return nil +} +func (l UserZone) Value() (driver.Value, error) { + return l.Location.String(), nil +} +func (l UserZone) MarshalJSON() ([]byte, error) { + return json.Marshal(l.Location.String()) +} +func (l *UserZone) UnmarshalJSON(bytes []byte) error { + var a string + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + return l.Scan(a) +} diff --git a/database/types/userzone_test.go b/database/types/userzone_test.go new file mode 100644 index 0000000..a1f2ef5 --- /dev/null +++ b/database/types/userzone_test.go @@ -0,0 +1,14 @@ +package types + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestUserZone_MarshalJSON(t *testing.T) { + location, err := time.LoadLocation("Europe/London") + assert.NoError(t, err) + assert.Equal(t, "\"Europe/London\"", encode(UserZone{location})) + assert.Equal(t, "\"UTC\"", encode(UserZone{time.UTC})) +} diff --git a/database/types/utils_test.go b/database/types/utils_test.go new file mode 100644 index 0000000..9f56874 --- /dev/null +++ b/database/types/utils_test.go @@ -0,0 +1,11 @@ +package types + +import "encoding/json" + +func encode(data any) string { + j, err := json.Marshal(data) + if err != nil { + panic(err) + } + return string(j) +} diff --git a/database/users.sql.go b/database/users.sql.go index f818a0a..231e337 100644 --- a/database/users.sql.go +++ b/database/users.sql.go @@ -8,6 +8,10 @@ package database import ( "context" "database/sql" + "time" + + "github.com/1f349/tulip/database/types" + "github.com/1f349/tulip/password" ) const deleteTwoFactor = `-- name: DeleteTwoFactor :exec @@ -40,44 +44,20 @@ func (q *Queries) GetTwoFactor(ctx context.Context, subject string) (GetTwoFacto } const getUser = `-- name: GetUser :one -SELECT name, - username, - picture, - website, - email, - email_verified, - pronouns, - birthdate, - zoneinfo, - locale, - updated_at, - active +SELECT subject, name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, registered, active FROM users WHERE subject = ? LIMIT 1 ` -type GetUserRow struct { - Name string `json:"name"` - Username string `json:"username"` - Picture interface{} `json:"picture"` - Website interface{} `json:"website"` - Email string `json:"email"` - EmailVerified int64 `json:"email_verified"` - Pronouns interface{} `json:"pronouns"` - Birthdate sql.NullTime `json:"birthdate"` - Zoneinfo interface{} `json:"zoneinfo"` - Locale interface{} `json:"locale"` - UpdatedAt sql.NullTime `json:"updated_at"` - Active sql.NullInt64 `json:"active"` -} - -func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, error) { +func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) { row := q.db.QueryRowContext(ctx, getUser, subject) - var i GetUserRow + var i User err := row.Scan( + &i.Subject, &i.Name, &i.Username, + &i.Password, &i.Picture, &i.Website, &i.Email, @@ -86,12 +66,51 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, erro &i.Birthdate, &i.Zoneinfo, &i.Locale, + &i.Role, &i.UpdatedAt, + &i.Registered, &i.Active, ) return i, err } +const getUserDisplayName = `-- name: GetUserDisplayName :one +SELECT name +FROM users +WHERE subject = ? +` + +func (q *Queries) GetUserDisplayName(ctx context.Context, subject string) (string, error) { + row := q.db.QueryRowContext(ctx, getUserDisplayName, subject) + var name string + err := row.Scan(&name) + return name, err +} + +const getUserRole = `-- name: GetUserRole :one +SELECT role +FROM users +WHERE subject = ? +` + +func (q *Queries) GetUserRole(ctx context.Context, subject string) (types.UserRole, error) { + row := q.db.QueryRowContext(ctx, getUserRole, subject) + var role types.UserRole + err := row.Scan(&role) + return role, err +} + +const hasTwoFactor = `-- name: HasTwoFactor :one +SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN) +` + +func (q *Queries) HasTwoFactor(ctx context.Context, subject string) (bool, error) { + row := q.db.QueryRowContext(ctx, hasTwoFactor, subject) + var column_1 bool + err := row.Scan(&column_1) + return column_1, err +} + const hasUser = `-- name: HasUser :one SELECT cast(count(subject) AS BOOLEAN) AS hasUser FROM users @@ -118,15 +137,15 @@ WHERE subject = ? ` type ModifyUserParams struct { - Name string `json:"name"` - Picture interface{} `json:"picture"` - Website interface{} `json:"website"` - Pronouns interface{} `json:"pronouns"` - Birthdate sql.NullTime `json:"birthdate"` - Zoneinfo interface{} `json:"zoneinfo"` - Locale interface{} `json:"locale"` - UpdatedAt sql.NullTime `json:"updated_at"` - Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate sql.NullTime `json:"birthdate"` + Zoneinfo types.UserZone `json:"zoneinfo"` + Locale types.UserLocale `json:"locale"` + UpdatedAt time.Time `json:"updated_at"` + Subject string `json:"subject"` } func (q *Queries) ModifyUser(ctx context.Context, arg ModifyUserParams) (int64, error) { @@ -171,15 +190,15 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ` type addUserParams struct { - Subject string `json:"subject"` - Name string `json:"name"` - Username string `json:"username"` - Password string `json:"password"` - Email string `json:"email"` - EmailVerified int64 `json:"email_verified"` - Role int64 `json:"role"` - UpdatedAt sql.NullTime `json:"updated_at"` - Active sql.NullInt64 `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Username string `json:"username"` + Password password.HashString `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Role types.UserRole `json:"role"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` } func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { @@ -206,10 +225,10 @@ WHERE subject = ? ` type changeUserPasswordParams struct { - Password string `json:"password"` - UpdatedAt sql.NullTime `json:"updated_at"` - Subject string `json:"subject"` - Password_2 string `json:"password_2"` + Password password.HashString `json:"password"` + UpdatedAt time.Time `json:"updated_at"` + Subject string `json:"subject"` + Password_2 password.HashString `json:"password_2"` } func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) { @@ -233,11 +252,11 @@ LIMIT 1 ` type checkLoginRow struct { - Subject string `json:"subject"` - Password string `json:"password"` - Column3 int64 `json:"column_3"` - Email string `json:"email"` - EmailVerified int64 `json:"email_verified"` + Subject string `json:"subject"` + Password password.HashString `json:"password"` + Column3 int64 `json:"column_3"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` } func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) { @@ -259,9 +278,9 @@ FROM users WHERE subject = ? ` -func (q *Queries) getUserPassword(ctx context.Context, subject string) (string, error) { +func (q *Queries) getUserPassword(ctx context.Context, subject string) (password.HashString, error) { row := q.db.QueryRowContext(ctx, getUserPassword, subject) - var password string + var password password.HashString err := row.Scan(&password) return password, err } diff --git a/server/auth.go b/server/auth.go index b0c09ab..a57856b 100644 --- a/server/auth.go +++ b/server/auth.go @@ -4,6 +4,7 @@ import ( "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" "github.com/1f349/tulip/database" + "github.com/1f349/tulip/database/types" "github.com/julienschmidt/httprouter" "net/http" "net/url" @@ -30,14 +31,14 @@ func (u UserAuth) IsGuest() bool { func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle { return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { - var role database.UserRole - if h.DbTx(rw, func(tx *database.Tx) (err error) { - role, err = tx.GetUserRole(auth.ID) + var role types.UserRole + if h.DbTx(rw, func(tx *database.Queries) (err error) { + role, err = tx.GetUserRole(req.Context(), auth.ID) return }) { return } - if role != database.RoleAdmin { + if role != types.RoleAdmin { http.Error(rw, "403 Forbidden", http.StatusForbidden) return } diff --git a/server/db.go b/server/db.go index 4836f93..d6d970f 100644 --- a/server/db.go +++ b/server/db.go @@ -9,25 +9,13 @@ import ( // DbTx wraps a database transaction with http error messages and a simple action // function. If the action function returns an error the transaction will be // rolled back. If there is no error then the transaction is committed. -func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool { - tx, err := h.db.Begin() - if err != nil { - http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError) - return true - } - defer tx.Rollback() - - err = action(tx) +func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool { + err := action(h.db) if err != nil { http.Error(rw, "Database error", http.StatusInternalServerError) log.Println("Database action error:", err) return true } - err = tx.Commit() - if err != nil { - http.Error(rw, "Database error", http.StatusInternalServerError) - log.Println("Database commit error:", err) - } return false } diff --git a/server/edit.go b/server/edit.go index 92633ae..4c96d5f 100644 --- a/server/edit.go +++ b/server/edit.go @@ -11,12 +11,12 @@ import ( "time" ) -func (h *HttpServer) EditGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) { - var user *database.User +func (h *HttpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + var user database.User - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { var err error - user, err = tx.GetUser(auth.ID) + user, err = tx.GetUser(req.Context(), auth.ID) if err != nil { return fmt.Errorf("failed to read user data: %w", err) } @@ -64,8 +64,19 @@ func (h *HttpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httpr _, _ = fmt.Fprintln(rw, "\n") return } - if h.DbTx(rw, func(tx *database.Tx) error { - if err := tx.ModifyUser(auth.ID, &patch); err != nil { + m := database.ModifyUserParams{ + Name: patch.Name, + Picture: patch.Picture, + Website: patch.Website, + Pronouns: patch.Pronouns, + Birthdate: patch.Birthdate, + Zoneinfo: patch.ZoneInfo, + Locale: patch.Locale, + UpdatedAt: time.Now(), + Subject: auth.ID, + } + if h.DbTx(rw, func(tx *database.Queries) error { + if _, err := tx.ModifyUser(req.Context(), m); err != nil { return fmt.Errorf("failed to modify user info: %w", err) } return nil diff --git a/server/home.go b/server/home.go index 759119c..802fdfa 100644 --- a/server/home.go +++ b/server/home.go @@ -3,6 +3,7 @@ package server import ( "fmt" "github.com/1f349/tulip/database" + "github.com/1f349/tulip/database/types" "github.com/1f349/tulip/pages" "github.com/google/uuid" "github.com/julienschmidt/httprouter" @@ -29,18 +30,19 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute return } - var userWithName *database.User + var userWithName string + var userRole types.UserRole var hasTwoFactor bool - if h.DbTx(rw, func(tx *database.Tx) (err error) { - userWithName, err = tx.GetUserDisplayName(auth.ID) + if h.DbTx(rw, func(tx *database.Queries) (err error) { + userWithName, err = tx.GetUserDisplayName(req.Context(), auth.ID) if err != nil { return fmt.Errorf("failed to get user display name: %w", err) } - hasTwoFactor, err = tx.HasTwoFactor(auth.ID) + hasTwoFactor, err = tx.HasTwoFactor(req.Context(), auth.ID) if err != nil { return fmt.Errorf("failed to get user two factor state: %w", err) } - userWithName.Role, err = tx.GetUserRole(auth.ID) + userRole, err = tx.GetUserRole(req.Context(), auth.ID) if err != nil { return fmt.Errorf("failed to get user role: %w", err) } @@ -54,6 +56,6 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute "User": userWithName, "Nonce": lNonce, "OtpEnabled": hasTwoFactor, - "IsAdmin": userWithName.Role, + "IsAdmin": userRole == types.RoleAdmin, }) } diff --git a/server/id_token.go b/server/id_token.go index 58af98f..4a59edc 100644 --- a/server/id_token.go +++ b/server/id_token.go @@ -1,6 +1,7 @@ package server import ( + "context" "github.com/1f349/mjwt" "github.com/1f349/tulip/database" "github.com/go-oauth2/oauth2/v4" @@ -9,7 +10,7 @@ import ( "strings" ) -func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) { +func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) { srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) { scope := ti.GetScope() if containsScope(scope, "openid") { @@ -30,18 +31,13 @@ type IdTokenClaims struct{} func (a IdTokenClaims) Valid() error { return nil } func (a IdTokenClaims) Type() string { return "access-token" } -func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) { - tx, err := us.Begin() +func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) { + user, err := us.GetUser(context.Background(), ti.GetUserID()) if err != nil { return "", err } - user, err := tx.GetUser(ti.GetUserID()) - if err != nil { - return "", err - } - tx.Rollback() - token, err = key.GenerateJwt(user.Sub, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{}) + token, err = key.GenerateJwt(user.Subject, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{}) return } diff --git a/server/login.go b/server/login.go index 15061a2..fd8a14f 100644 --- a/server/login.go +++ b/server/login.go @@ -62,7 +62,7 @@ func (h *HttpServer) LoginPost(rw http.ResponseWriter, req *http.Request, _ http var loginMismatch byte var hasOtp bool - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw) if err != nil { if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { @@ -176,7 +176,7 @@ func (h *HttpServer) LoginResetPasswordPost(rw http.ResponseWriter, req *http.Re } var emailExists bool - if h.DbTx(rw, func(tx *database.Tx) (err error) { + if h.DbTx(rw, func(tx *database.Queries) (err error) { emailExists, err = tx.UserEmailExists(email) return err }) { diff --git a/server/mail.go b/server/mail.go index dbe591c..a188c55 100644 --- a/server/mail.go +++ b/server/mail.go @@ -18,7 +18,7 @@ func (h *HttpServer) MailVerify(rw http.ResponseWriter, _ *http.Request, params http.Error(rw, "Invalid email verification code", http.StatusBadRequest) return } - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { return tx.VerifyUserEmail(userSub) }) { return @@ -75,7 +75,7 @@ func (h *HttpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request, h.mailLinkCache.Delete(k) // reset password database call - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { return tx.UserResetPassword(userSub, pw) }) { return @@ -94,12 +94,12 @@ func (h *HttpServer) MailDelete(rw http.ResponseWriter, _ *http.Request, params return } var userInfo *database.User - if h.DbTx(rw, func(tx *database.Tx) (err error) { + if h.DbTx(rw, func(tx *database.Queries) (err error) { userInfo, err = tx.GetUser(userSub) if err != nil { return } - return tx.UpdateUser(userSub, database.RoleToDelete, false) + return tx.UpdateUser(userSub, types.RoleToDelete, false) }) { return } diff --git a/server/manage-apps.go b/server/manage-apps.go index 586174f..29a554f 100644 --- a/server/manage-apps.go +++ b/server/manage-apps.go @@ -22,14 +22,14 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ } } - var role database.UserRole - var appList []database.ClientInfoDbOutput - if h.DbTx(rw, func(tx *database.Tx) (err error) { + var role types.UserRole + var appList []database.ClientStore + if h.DbTx(rw, func(tx *database.Queries) (err error) { role, err = tx.GetUserRole(auth.ID) if err != nil { return } - appList, err = tx.GetAppList(auth.ID, role == database.RoleAdmin, offset) + appList, err = tx.GetAppList(auth.ID, role == types.RoleAdmin, offset) return }) { return @@ -39,7 +39,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ "ServiceName": h.conf.ServiceName, "Apps": appList, "Offset": offset, - "IsAdmin": role == database.RoleAdmin, + "IsAdmin": role == types.RoleAdmin, "NewAppName": q.Get("NewAppName"), "NewAppSecret": q.Get("NewAppSecret"), } @@ -76,14 +76,14 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ active := req.Form.Has("active") if sso { - var role database.UserRole - if h.DbTx(rw, func(tx *database.Tx) (err error) { + var role types.UserRole + if h.DbTx(rw, func(tx *database.Queries) (err error) { role, err = tx.GetUserRole(auth.ID) return }) { return } - if role != database.RoleAdmin { + if role != types.RoleAdmin { http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest) return } @@ -91,13 +91,13 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ switch action { case "create": - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { return tx.InsertClientApp(name, domain, public, sso, active, auth.ID) }) { return } case "edit": - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active) }) { return @@ -105,7 +105,7 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ case "secret": var info oauth2.ClientInfo var secret string - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { sub := req.Form.Get("subject") info, err = tx.GetClientInfo(sub) if err != nil { diff --git a/server/manage-users.go b/server/manage-users.go index 87d38f6..c40781c 100644 --- a/server/manage-users.go +++ b/server/manage-users.go @@ -3,6 +3,7 @@ package server import ( "errors" "github.com/1f349/tulip/database" + "github.com/1f349/tulip/database/types" "github.com/1f349/tulip/pages" "github.com/emersion/go-message/mail" "github.com/google/uuid" @@ -27,9 +28,9 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ } } - var role database.UserRole + var role types.UserRole var userList []database.User - if h.DbTx(rw, func(tx *database.Tx) (err error) { + if h.DbTx(rw, func(tx *database.Queries) (err error) { role, err = tx.GetUserRole(auth.ID) if err != nil { return @@ -39,7 +40,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ }) { return } - if role != database.RoleAdmin { + if role != types.RoleAdmin { http.Error(rw, "403 Forbidden", http.StatusForbidden) return } @@ -76,14 +77,14 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, return } - var role database.UserRole - if h.DbTx(rw, func(tx *database.Tx) (err error) { + var role types.UserRole + if h.DbTx(rw, func(tx *database.Queries) (err error) { role, err = tx.GetUserRole(auth.ID) return }) { return } - if role != database.RoleAdmin { + if role != types.RoleAdmin { http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest) return } @@ -116,7 +117,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, addrDomain := address.Address[n+1:] var userSub uuid.UUID - if h.DbTx(rw, func(tx *database.Tx) (err error) { + if h.DbTx(rw, func(tx *database.Queries) (err error) { userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active) return err }) { @@ -136,7 +137,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, return } case "edit": - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { sub := req.Form.Get("subject") return tx.UpdateUser(sub, newRole, active) }) { @@ -151,12 +152,12 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, http.Redirect(rw, req, redirectUrl.String(), http.StatusFound) } -func parseRoleValue(role string) (database.UserRole, error) { +func parseRoleValue(role string) (types.UserRole, error) { switch role { case "member": - return database.RoleMember, nil + return types.RoleMember, nil case "admin": - return database.RoleAdmin, nil + return types.RoleAdmin, nil } return 0, errors.New("invalid role value") } diff --git a/server/oauth.go b/server/oauth.go index 907adf8..1b9bf88 100644 --- a/server/oauth.go +++ b/server/oauth.go @@ -79,7 +79,7 @@ func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request var user *database.User var hasOtp bool - if h.DbTx(rw, func(tx *database.Tx) (err error) { + if h.DbTx(rw, func(tx *database.Queries) (err error) { user, err = tx.GetUserDisplayName(auth.ID) if err != nil { return diff --git a/server/otp.go b/server/otp.go index 0e22b5b..17df3c1 100644 --- a/server/otp.go +++ b/server/otp.go @@ -47,7 +47,7 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code strin var hasOtp bool var secret string var digits int - if h.DbTx(rw, func(tx *database.Tx) (err error) { + if h.DbTx(rw, func(tx *database.Queries) (err error) { hasOtp, err = tx.HasTwoFactor(sub) if err != nil { return @@ -86,7 +86,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht return } - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { return tx.SetTwoFactor(auth.ID, "", 0) }) { return @@ -118,7 +118,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht if secret == "" { // get user email var email string - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { var err error email, err = tx.GetUserEmail(auth.ID) return err @@ -167,7 +167,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht return } - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { return tx.SetTwoFactor(auth.ID, secret, digits) }) { return diff --git a/server/server.go b/server/server.go index 4491b82..5deb729 100644 --- a/server/server.go +++ b/server/server.go @@ -31,7 +31,7 @@ type HttpServer struct { r *httprouter.Router oauthSrv *server.Server oauthMgr *manage.Manager - db *database.DB + db *database.Queries conf Conf signingKey mjwt.Signer @@ -50,7 +50,7 @@ type mailLinkKey struct { data string } -func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server { +func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server { r := httprouter.New() // remove last slash from baseUrl @@ -191,10 +191,10 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser return } - var userData *database.User + var userData database.GetUserRow - if hs.DbTx(rw, func(tx *database.Tx) (err error) { - userData, err = tx.GetUser(userId) + if hs.DbTx(rw, func(tx *database.Queries) (err error) { + userData, err = tx.GetUser(req.Context(), userId) return err }) { return diff --git a/sqlc.yaml b/sqlc.yaml index 953e616..998fc1d 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -9,7 +9,13 @@ sql: out: "database" emit_json_tags: true overrides: - - column: "routes.flags" - go_type: "github.com/1f349/violet/target.Flags" - - column: "redirects.flags" - go_type: "github.com/1f349/violet/target.Flags" + - column: "users.password" + go_type: "github.com/1f349/tulip/password.HashString" + - column: "users.role" + go_type: "github.com/1f349/tulip/database/types.UserRole" + - column: "users.pronouns" + go_type: "github.com/1f349/tulip/database/types.UserPronoun" + - column: "users.zoneinfo" + go_type: "github.com/1f349/tulip/database/types.UserZone" + - column: "users.locale" + go_type: "github.com/1f349/tulip/database/types.UserLocale"