Correct refactored database calls

This commit is contained in:
Melon 2024-03-11 12:39:52 +00:00
parent fbd49da2db
commit 37570e2157
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
39 changed files with 604 additions and 934 deletions

View File

@ -7,20 +7,16 @@ import (
)
type ClientStore struct {
db *database.DB
db *database.Queries
}
var _ oauth2.ClientStore = &ClientStore{}
func New(db *database.DB) *ClientStore {
func New(db *database.Queries) *ClientStore {
return &ClientStore{db: db}
}
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
tx, err := c.db.BeginCtx(ctx)
if err != nil {
return nil, err
}
defer tx.Rollback()
return tx.GetClientInfo(id)
a, err := c.db.GetClientInfo(ctx, id)
return &a, err
}

View File

@ -118,7 +118,7 @@ func genHmacKey() []byte {
return a
}
func checkDbHasUser(db *database.DB) error {
func checkDbHasUser(db *database.Queries) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
@ -126,7 +126,7 @@ func checkDbHasUser(db *database.DB) error {
defer tx.Rollback()
if err := tx.HasUser(); err != nil {
if errors.Is(err, sql.ErrNoRows) {
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, database.RoleAdmin, false)
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, types.RoleAdmin, false)
if err != nil {
return fmt.Errorf("failed to add user: %w", err)
}

View File

@ -0,0 +1,82 @@
package database
import (
"database/sql"
"fmt"
"github.com/1f349/tulip/database/types"
"github.com/MrMelon54/pronouns"
"github.com/go-oauth2/oauth2/v4"
"golang.org/x/text/language"
"net/url"
"time"
)
type UserPatch struct {
Name string
Picture string
Website string
Pronouns types.UserPronoun
Birthdate sql.NullTime
ZoneInfo types.UserZone
Locale types.UserLocale
}
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
var err error
u.Name = v.Get("name")
u.Picture = v.Get("picture")
u.Website = v.Get("website")
if v.Has("reset_pronouns") {
u.Pronouns.Pronoun = pronouns.TheyThem
} else {
u.Pronouns.Pronoun, err = pronouns.FindPronoun(v.Get("pronouns"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
}
}
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
u.Birthdate = sql.NullTime{}
} else {
u.Birthdate = sql.NullTime{Valid: true}
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
}
}
if v.Has("reset_zoneinfo") {
u.ZoneInfo.Location = time.UTC
} else {
u.ZoneInfo.Location, err = time.LoadLocation(v.Get("zoneinfo"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
}
}
if v.Has("reset_locale") {
u.Locale.Tag = language.AmericanEnglish
} else {
u.Locale.Tag, err = language.Parse(v.Get("locale"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
}
}
return
}
var _ oauth2.ClientInfo = &ClientStore{}
func (c *ClientStore) GetID() string { return c.Subject }
func (c *ClientStore) GetSecret() string { return c.Secret }
func (c *ClientStore) GetDomain() string { return c.Domain }
func (c *ClientStore) IsPublic() bool { return c.Public }
func (c *ClientStore) GetUserID() string { return c.Owner }
// GetName is an extra field for the oauth handler to display the application
// name
func (c *ClientStore) GetName() string { return c.Name }
// IsSSO is an extra field for the oauth handler to skip the user input stage
// this is for trusted applications to get permissions without asking the user
func (c *ClientStore) IsSSO() bool { return c.Sso }
// IsActive is an extra field for the app manager to get the active state
func (c *ClientStore) IsActive() bool { return c.Active }

View File

@ -1,145 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/MrMelon54/pronouns"
"golang.org/x/text/language"
"time"
)
var (
_, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
_, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
_, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
)
func marshalValueOrNull(null bool, data any) ([]byte, error) {
if null {
return json.Marshal(nil)
}
return json.Marshal(data)
}
type NullStringScanner struct{ sql.NullString }
func (s *NullStringScanner) Null() bool { return !s.Valid }
func (s *NullStringScanner) Scan(src any) error { return s.NullString.Scan(src) }
func (s NullStringScanner) MarshalJSON() ([]byte, error) {
return marshalValueOrNull(s.Null(), s.NullString.String)
}
func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error {
if string(bytes) == "null" {
return s.Scan(nil)
}
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return s.Scan(&a)
}
func (s NullStringScanner) String() string {
if s.Null() {
return ""
}
return s.NullString.String
}
type NullDateScanner struct{ sql.NullTime }
func (t *NullDateScanner) Null() bool { return !t.Valid }
func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) }
func (t NullDateScanner) MarshalJSON() ([]byte, error) {
return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly))
}
func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error {
if string(bytes) == "null" {
return t.Scan(nil)
}
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return t.Scan(&a)
}
func (t NullDateScanner) String() string {
if t.Null() {
return ""
}
return t.NullTime.Time.UTC().Format(time.DateOnly)
}
type LocationScanner struct{ *time.Location }
func (l *LocationScanner) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
loc, err := time.LoadLocation(s)
if err != nil {
return err
}
l.Location = loc
return nil
}
func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) }
func (l *LocationScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}
type LocaleScanner struct{ language.Tag }
func (l *LocaleScanner) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
lang, err := language.Parse(s)
if err != nil {
return err
}
l.Tag = lang
return nil
}
func (l LocaleScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
func (l *LocaleScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}
type PronounScanner struct{ pronouns.Pronoun }
func (p *PronounScanner) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
}
pro, err := pronouns.FindPronoun(s)
if err != nil {
return err
}
p.Pronoun = pro
return nil
}
func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
func (p *PronounScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return p.Scan(a)
}

View File

@ -1,52 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"github.com/MrMelon54/pronouns"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"testing"
"time"
)
func encode(data any) string {
j, err := json.Marshal(map[string]any{"value": data})
if err != nil {
panic(err)
}
return string(j)
}
func TestStringScanner_MarshalJSON(t *testing.T) {
assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}}))
assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}}))
}
func TestDateScanner_MarshalJSON(t *testing.T) {
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}}))
assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}}))
assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{}))
}
func TestLocationScanner_MarshalJSON(t *testing.T) {
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location}))
assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC}))
}
func TestLocaleScanner_MarshalJSON(t *testing.T) {
assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish}))
assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish}))
}
func TestPronounScanner_MarshalJSON(t *testing.T) {
assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem}))
assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim}))
assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer}))
assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts}))
assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes}))
}

View File

@ -1,127 +0,0 @@
package database
import (
"database/sql"
"fmt"
"github.com/MrMelon54/pronouns"
"github.com/go-oauth2/oauth2/v4"
"golang.org/x/text/language"
"net/url"
"time"
)
type User struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
Username string `json:"username"`
Picture NullStringScanner `json:"picture,omitempty"`
Website NullStringScanner `json:"website,omitempty"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Pronouns PronounScanner `json:"pronouns,omitempty"`
Birthdate NullDateScanner `json:"birthdate,omitempty"`
ZoneInfo LocationScanner `json:"zoneinfo,omitempty"`
Locale LocaleScanner `json:"locale,omitempty"`
Role UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Active bool `json:"active"`
}
type UserRole int
const (
RoleMember UserRole = iota
RoleAdmin
RoleToDelete
)
func (r UserRole) String() string {
switch r {
case RoleMember:
return "Member"
case RoleAdmin:
return "Admin"
case RoleToDelete:
return "ToDelete"
}
return fmt.Sprintf("UserRole{ %d }", r)
}
func (r UserRole) IsValid() bool {
return r == RoleMember || r == RoleAdmin
}
type UserPatch struct {
Name string
Picture string
Website string
Pronouns pronouns.Pronoun
Birthdate sql.NullTime
ZoneInfo *time.Location
Locale language.Tag
}
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
var err error
u.Name = v.Get("name")
u.Picture = v.Get("picture")
u.Website = v.Get("website")
if v.Has("reset_pronouns") {
u.Pronouns = pronouns.TheyThem
} else {
u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
}
}
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
u.Birthdate = sql.NullTime{}
} else {
u.Birthdate = sql.NullTime{Valid: true}
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
}
}
if v.Has("reset_zoneinfo") {
u.ZoneInfo = time.UTC
} else {
u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
}
}
if v.Has("reset_locale") {
u.Locale = language.AmericanEnglish
} else {
u.Locale, err = language.Parse(v.Get("locale"))
if err != nil {
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
}
}
return
}
type ClientInfoDbOutput struct {
Sub, Name, Secret, Domain, Owner string
Public, SSO, Active bool
}
var _ oauth2.ClientInfo = &ClientInfoDbOutput{}
func (c *ClientInfoDbOutput) GetID() string { return c.Sub }
func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret }
func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain }
func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public }
func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner }
// GetName is an extra field for the oauth handler to display the application
// name
func (c *ClientInfoDbOutput) GetName() string { return c.Name }
// IsSSO is an extra field for the oauth handler to skip the user input stage
// this is for trusted applications to get permissions without asking the user
func (c *ClientInfoDbOutput) IsSSO() bool { return c.SSO }
// IsActive is an extra field for the app manager to get the active state
func (c *ClientInfoDbOutput) IsActive() bool { return c.Active }

View File

@ -1,5 +0,0 @@
package database
import (
_ "github.com/mattn/go-sqlite3"
)

View File

@ -7,7 +7,6 @@ package database
import (
"context"
"database/sql"
)
const getAppList = `-- name: GetAppList :many
@ -25,13 +24,13 @@ type GetAppListParams struct {
}
type GetAppListRow struct {
Subject string `json:"subject"`
Name string `json:"name"`
Domain string `json:"domain"`
Owner string `json:"owner"`
Public sql.NullInt64 `json:"public"`
Sso sql.NullInt64 `json:"sso"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Name string `json:"name"`
Domain string `json:"domain"`
Owner string `json:"owner"`
Public bool `json:"public"`
Sso bool `json:"sso"`
Active bool `json:"active"`
}
func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) {
@ -66,28 +65,21 @@ func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAp
}
const getClientInfo = `-- name: GetClientInfo :one
SELECT secret, name, domain, public, sso, active
SELECT subject, name, secret, domain, owner, public, sso, active
FROM client_store
WHERE subject = ?
LIMIT 1
`
type GetClientInfoRow struct {
Secret string `json:"secret"`
Name string `json:"name"`
Domain string `json:"domain"`
Public sql.NullInt64 `json:"public"`
Sso sql.NullInt64 `json:"sso"`
Active sql.NullInt64 `json:"active"`
}
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (GetClientInfoRow, error) {
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) {
row := q.db.QueryRowContext(ctx, getClientInfo, subject)
var i GetClientInfoRow
var i ClientStore
err := row.Scan(
&i.Secret,
&i.Subject,
&i.Name,
&i.Secret,
&i.Domain,
&i.Owner,
&i.Public,
&i.Sso,
&i.Active,
@ -101,14 +93,14 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
type InsertClientAppParams struct {
Subject string `json:"subject"`
Name string `json:"name"`
Secret string `json:"secret"`
Domain string `json:"domain"`
Owner string `json:"owner"`
Public sql.NullInt64 `json:"public"`
Sso sql.NullInt64 `json:"sso"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Name string `json:"name"`
Secret string `json:"secret"`
Domain string `json:"domain"`
Owner string `json:"owner"`
Public bool `json:"public"`
Sso bool `json:"sso"`
Active bool `json:"active"`
}
func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error {
@ -137,13 +129,13 @@ WHERE subject = ?
`
type UpdateClientAppParams struct {
Name string `json:"name"`
Domain string `json:"domain"`
Public sql.NullInt64 `json:"public"`
Sso sql.NullInt64 `json:"sso"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Owner string `json:"owner"`
Name string `json:"name"`
Domain string `json:"domain"`
Public bool `json:"public"`
Sso bool `json:"sso"`
Active bool `json:"active"`
Subject string `json:"subject"`
Owner string `json:"owner"`
}
func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error {

View File

@ -7,7 +7,9 @@ package database
import (
"context"
"database/sql"
"time"
"github.com/1f349/tulip/database/types"
)
const getUserList = `-- name: GetUserList :many
@ -25,15 +27,15 @@ LIMIT 25 OFFSET ?
`
type GetUserListRow struct {
Subject string `json:"subject"`
Name string `json:"name"`
Username string `json:"username"`
Picture interface{} `json:"picture"`
Email string `json:"email"`
EmailVerified int64 `json:"email_verified"`
Role int64 `json:"role"`
UpdatedAt sql.NullTime `json:"updated_at"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Name string `json:"name"`
Username string `json:"username"`
Picture string `json:"picture"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Role types.UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Active bool `json:"active"`
}
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
@ -77,9 +79,9 @@ WHERE subject = ?
`
type UpdateUserRoleParams struct {
Active sql.NullInt64 `json:"active"`
Role int64 `json:"role"`
Subject string `json:"subject"`
Active bool `json:"active"`
Role types.UserRole `json:"role"`
Subject string `json:"subject"`
}
func (q *Queries) UpdateUserRole(ctx context.Context, arg UpdateUserRoleParams) error {

View File

@ -0,0 +1,4 @@
DROP TABLE users;
DROP INDEX username_index;
DROP TABLE client_store;
DROP TABLE otp;

View File

@ -1,39 +1,39 @@
CREATE TABLE IF NOT EXISTS users
CREATE TABLE users
(
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
name TEXT NOT NULL,
username TEXT UNIQUE NOT NULL,
password TEXT NOT NULL,
picture TEXT DEFAULT "" NOT NULL,
website TEXT DEFAULT "" NOT NULL,
picture TEXT DEFAULT '' NOT NULL,
website TEXT DEFAULT '' NOT NULL,
email TEXT NOT NULL,
email_verified INTEGER DEFAULT 0 NOT NULL,
pronouns TEXT DEFAULT "they/them" NOT NULL,
email_verified BOOLEAN DEFAULT 0 NOT NULL,
pronouns TEXT DEFAULT 'they/them' NOT NULL,
birthdate DATE,
zoneinfo TEXT DEFAULT "UTC" NOT NULL,
locale TEXT DEFAULT "en-US" NOT NULL,
zoneinfo TEXT DEFAULT 'UTC' NOT NULL,
locale TEXT DEFAULT 'en-US' NOT NULL,
role INTEGER DEFAULT 0 NOT NULL,
updated_at DATETIME,
registered INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
updated_at DATETIME NOT NULL,
registered DATETIME NOT NULL,
active BOOLEAN DEFAULT 1 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS username_index ON users (username);
CREATE UNIQUE INDEX username_index ON users (username);
CREATE TABLE IF NOT EXISTS client_store
CREATE TABLE client_store
(
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
name TEXT NOT NULL,
secret TEXT UNIQUE NOT NULL,
domain TEXT NOT NULL,
owner TEXT NOT NULL,
public INTEGER,
sso INTEGER,
active INTEGER DEFAULT 1,
public BOOLEAN NOT NULL,
sso BOOLEAN NOT NULL,
active BOOLEAN DEFAULT 1 NOT NULL,
FOREIGN KEY (owner) REFERENCES users (subject)
);
CREATE TABLE IF NOT EXISTS otp
CREATE TABLE otp
(
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
secret TEXT NOT NULL,

View File

@ -6,17 +6,21 @@ package database
import (
"database/sql"
"time"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/password"
)
type ClientStore struct {
Subject string `json:"subject"`
Name string `json:"name"`
Secret string `json:"secret"`
Domain string `json:"domain"`
Owner string `json:"owner"`
Public sql.NullInt64 `json:"public"`
Sso sql.NullInt64 `json:"sso"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Name string `json:"name"`
Secret string `json:"secret"`
Domain string `json:"domain"`
Owner string `json:"owner"`
Public bool `json:"public"`
Sso bool `json:"sso"`
Active bool `json:"active"`
}
type Otp struct {
@ -26,20 +30,20 @@ type Otp struct {
}
type User struct {
Subject string `json:"subject"`
Name string `json:"name"`
Username string `json:"username"`
Password string `json:"password"`
Picture interface{} `json:"picture"`
Website interface{} `json:"website"`
Email string `json:"email"`
EmailVerified int64 `json:"email_verified"`
Pronouns interface{} `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo interface{} `json:"zoneinfo"`
Locale interface{} `json:"locale"`
Role int64 `json:"role"`
UpdatedAt sql.NullTime `json:"updated_at"`
Registered sql.NullInt64 `json:"registered"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Name string `json:"name"`
Username string `json:"username"`
Password password.HashString `json:"password"`
Picture string `json:"picture"`
Website string `json:"website"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Pronouns types.UserPronoun `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo types.UserZone `json:"zoneinfo"`
Locale types.UserLocale `json:"locale"`
Role types.UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Registered time.Time `json:"registered"`
Active bool `json:"active"`
}

View File

@ -0,0 +1,47 @@
package database
import (
"context"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/password"
"github.com/google/uuid"
"time"
)
type AddUserParams struct {
Name string `json:"name"`
Username string `json:"username"`
Password string `json:"password"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Role types.UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Active bool `json:"active"`
}
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) {
pwHash, err := password.HashPassword(arg.Password)
if err != nil {
return "", err
}
a := addUserParams{
Subject: uuid.NewString(),
Name: arg.Name,
Username: arg.Username,
Password: pwHash,
Email: arg.Email,
EmailVerified: arg.EmailVerified,
Role: arg.Role,
UpdatedAt: arg.UpdatedAt,
Active: arg.Active,
}
return a.Subject, q.addUser(ctx, a)
}
type CheckLoginRow struct {
Subject string `json:"subject"`
Password password.HashString `json:"password"`
HasTwoFactor bool `json:"hasTwoFactor"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}

View File

@ -1,5 +1,5 @@
-- name: GetClientInfo :one
SELECT secret, name, domain, public, sso, active
SELECT *
FROM client_store
WHERE subject = ?
LIMIT 1;

View File

@ -13,22 +13,21 @@ WHERE username = ?
LIMIT 1;
-- name: GetUser :one
SELECT name,
username,
picture,
website,
email,
email_verified,
pronouns,
birthdate,
zoneinfo,
locale,
updated_at,
active
SELECT *
FROM users
WHERE subject = ?
LIMIT 1;
-- name: GetUserRole :one
SELECT role
FROM users
WHERE subject = ?;
-- name: GetUserDisplayName :one
SELECT name
FROM users
WHERE subject = ?;
-- name: getUserPassword :one
SELECT password
FROM users
@ -68,3 +67,6 @@ WHERE otp.subject = ?;
SELECT secret, digits
FROM otp
WHERE subject = ?;
-- name: HasTwoFactor :one
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN);

View File

@ -1,302 +0,0 @@
package database
import (
"database/sql"
"fmt"
"github.com/1f349/tulip/password"
"github.com/go-oauth2/oauth2/v4"
"github.com/google/uuid"
"time"
)
func updatedAt() string {
return time.Now().UTC().Format(time.DateTime)
}
type Tx struct{ tx *sql.Tx }
func (t *Tx) Commit() error {
return t.tx.Commit()
}
func (t *Tx) Rollback() {
_ = t.tx.Rollback()
}
func (t *Tx) HasUser() error {
var exists bool
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
err := row.Scan(&exists)
if err != nil {
return err
}
if !exists {
return sql.ErrNoRows
}
return nil
}
func (t *Tx) InsertUser(name, un, pw, email string, verifyEmail bool, role UserRole, active bool) (uuid.UUID, error) {
pwHash, err := password.HashPassword(pw)
if err != nil {
return uuid.UUID{}, err
}
u := uuid.New()
_, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email, email_verified, role, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u, name, un, pwHash, email, verifyEmail, role, updatedAt(), active)
return u, err
}
func (t *Tx) CheckLogin(un, pw string) (*User, bool, bool, error) {
var u User
var pwHash password.HashString
var hasOtp, hasVerify bool
row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject), email, email_verified FROM users WHERE username = ?`, un)
err := row.Scan(&u.Sub, &pwHash, &hasOtp, &u.Email, &hasVerify)
if err != nil {
return nil, false, false, err
}
err = password.CheckPasswordHash(pwHash, pw)
return &u, hasOtp, hasVerify, err
}
func (t *Tx) GetUserDisplayName(sub string) (*User, error) {
var u User
row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&u.Name)
u.Sub = sub
return &u, err
}
func (t *Tx) GetUserRole(sub string) (UserRole, error) {
var r UserRole
row := t.tx.QueryRow(`SELECT role FROM users WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&r)
return r, err
}
func (t *Tx) GetUser(sub string) (*User, error) {
var u User
row := t.tx.QueryRow(`SELECT name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub)
err := row.Scan(&u.Name, &u.Username, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
u.Sub = sub
return &u, err
}
func (t *Tx) GetUserEmail(sub string) (string, error) {
var email string
row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub)
err := row.Scan(&email)
return email, err
}
func (t *Tx) ChangeUserPassword(sub, pwOld, pwNew string) error {
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
if err != nil {
return err
}
var pwHash password.HashString
if q.Next() {
err = q.Scan(&pwHash)
if err != nil {
return err
}
} else {
return fmt.Errorf("invalid user")
}
if err := q.Err(); err != nil {
return err
}
if err := q.Close(); err != nil {
return err
}
err = password.CheckPasswordHash(pwHash, pwOld)
if err != nil {
return err
}
pwNewHash, err := password.HashPassword(pwNew)
if err != nil {
return err
}
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, updatedAt(), sub, pwHash)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) ModifyUser(sub string, v *UserPatch) error {
exec, err := t.tx.Exec(
`UPDATE users
SET name = ?,
picture = ?,
website = ?,
pronouns = ?,
birthdate = ?,
zoneinfo = ?,
locale = ?,
updated_at = ?
WHERE subject = ?`,
v.Name,
v.Picture,
v.Website,
v.Pronouns.String(),
v.Birthdate,
v.ZoneInfo.String(),
v.Locale.String(),
updatedAt(),
sub,
)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) SetTwoFactor(sub string, secret string, digits int) error {
if secret == "" && digits == 0 {
_, err := t.tx.Exec(`DELETE FROM otp WHERE otp.subject = ?`, sub)
return err
}
_, err := t.tx.Exec(`INSERT INTO otp(subject, secret, digits) VALUES (?, ?, ?) ON CONFLICT(subject) DO UPDATE SET secret = excluded.secret, digits = excluded.digits`, sub, secret, digits)
return err
}
func (t *Tx) GetTwoFactor(sub string) (string, int, error) {
var secret string
var digits int
row := t.tx.QueryRow(`SELECT secret, digits FROM otp WHERE subject = ?`, sub)
err := row.Scan(&secret, &digits)
if err != nil {
return "", 0, err
}
return secret, digits, nil
}
func (t *Tx) HasTwoFactor(sub string) (bool, error) {
var hasOtp bool
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM otp WHERE otp.subject = ?)`, sub)
err := row.Scan(&hasOtp)
if err != nil {
return false, err
}
return hasOtp, row.Err()
}
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
var u ClientInfoDbOutput
row := t.tx.QueryRow(`SELECT secret, name, domain, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Public, &u.SSO, &u.Active)
u.Owner = sub
if !u.Active {
return nil, fmt.Errorf("client is not active")
}
return &u, err
}
func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) {
var u []ClientInfoDbOutput
row, err := t.tx.Query(`SELECT subject, name, domain, owner, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset)
if err != nil {
return nil, err
}
defer row.Close()
for row.Next() {
var a ClientInfoDbOutput
err := row.Scan(&a.Sub, &a.Name, &a.Domain, &a.Owner, &a.Public, &a.SSO, &a.Active)
if err != nil {
return nil, err
}
u = append(u, a)
}
return u, row.Err()
}
func (t *Tx) InsertClientApp(name, domain string, public, sso, active bool, owner string) error {
u := uuid.New()
secret, err := password.GenerateApiSecret(70)
if err != nil {
return err
}
_, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, public, sso, active)
return err
}
func (t *Tx) UpdateClientApp(subject, owner string, name, domain string, public, sso, active bool) error {
_, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, public, sso, active, subject, owner)
return err
}
func (t *Tx) ResetClientAppSecret(subject, owner string) (string, error) {
secret, err := password.GenerateApiSecret(70)
if err != nil {
return "", err
}
_, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject, owner)
return secret, err
}
func (t *Tx) GetUserList(offset int) ([]User, error) {
var u []User
row, err := t.tx.Query(`SELECT subject, name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset)
if err != nil {
return nil, err
}
for row.Next() {
var a User
err := row.Scan(&a.Sub, &a.Name, &a.Username, &a.Picture, &a.Website, &a.Email, &a.EmailVerified, &a.Pronouns, &a.Birthdate, &a.ZoneInfo, &a.Locale, &a.Role, &a.UpdatedAt, &a.Active)
if err != nil {
return nil, err
}
u = append(u, a)
}
return u, row.Err()
}
func (t *Tx) UpdateUser(subject string, role UserRole, active bool) error {
_, err := t.tx.Exec(`UPDATE users SET active = ?, role = ? WHERE subject = ?`, active, role, subject)
return err
}
func (t *Tx) VerifyUserEmail(sub string) error {
_, err := t.tx.Exec(`UPDATE users SET email_verified = 1 WHERE subject = ?`, sub)
return err
}
func (t *Tx) UserResetPassword(sub string, pw string) error {
hashPassword, err := password.HashPassword(pw)
if err != nil {
return err
}
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ?`, hashPassword, updatedAt(), sub)
if err != nil {
return err
}
affected, err := exec.RowsAffected()
if err != nil {
return err
}
if affected != 1 {
return fmt.Errorf("row wasn't updated")
}
return nil
}
func (t *Tx) UserEmailExists(email string) (exists bool, err error) {
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email)
err = row.Scan(&exists)
return
}

View File

@ -1,52 +0,0 @@
package database
import (
"github.com/1f349/tulip/password"
"github.com/MrMelon54/pronouns"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"testing"
"time"
)
func TestTx_ChangeUserPassword(t *testing.T) {
u := uuid.New()
pw, err := password.HashPassword("test")
assert.NoError(t, err)
d, err := Open("file::memory:")
assert.NoError(t, err)
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
assert.NoError(t, err)
tx, err := d.Begin()
assert.NoError(t, err)
err = tx.ChangeUserPassword(u.String(), "test", "new")
assert.NoError(t, err)
assert.NoError(t, tx.Commit())
query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test")
assert.NoError(t, err)
assert.True(t, query.Next())
var oldPw password.HashString
assert.NoError(t, query.Scan(&oldPw))
assert.NoError(t, password.CheckPasswordHash(oldPw, "new"))
assert.NoError(t, query.Err())
assert.NoError(t, query.Close())
}
func TestTx_ModifyUser(t *testing.T) {
u := uuid.New()
pw, err := password.HashPassword("test")
assert.NoError(t, err)
d, err := Open("file::memory:")
assert.NoError(t, err)
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
assert.NoError(t, err)
tx, err := d.Begin()
assert.NoError(t, err)
assert.NoError(t, tx.ModifyUser(u.String(), &UserPatch{
Name: "example",
Pronouns: pronouns.TheyThem,
ZoneInfo: time.UTC,
Locale: language.AmericanEnglish,
}))
}

View File

@ -0,0 +1,38 @@
package types
import (
"database/sql"
"encoding/json"
"fmt"
"golang.org/x/text/language"
)
var (
_ sql.Scanner = &UserLocale{}
_ json.Marshaler = &UserLocale{}
_ json.Unmarshaler = &UserLocale{}
)
type UserLocale struct{ language.Tag }
func (l *UserLocale) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
lang, err := language.Parse(s)
if err != nil {
return err
}
l.Tag = lang
return nil
}
func (l UserLocale) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
func (l *UserLocale) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}

View File

@ -0,0 +1,12 @@
package types
import (
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"testing"
)
func TestUserLocale_MarshalJSON(t *testing.T) {
assert.Equal(t, "\"en-US\"", encode(UserLocale{language.AmericanEnglish}))
assert.Equal(t, "\"en-GB\"", encode(UserLocale{language.BritishEnglish}))
}

View File

@ -0,0 +1,38 @@
package types
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/MrMelon54/pronouns"
)
var (
_ sql.Scanner = &UserPronoun{}
_ json.Marshaler = &UserPronoun{}
_ json.Unmarshaler = &UserPronoun{}
)
type UserPronoun struct{ pronouns.Pronoun }
func (p *UserPronoun) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
}
pro, err := pronouns.FindPronoun(s)
if err != nil {
return err
}
p.Pronoun = pro
return nil
}
func (p UserPronoun) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
func (p *UserPronoun) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return p.Scan(a)
}

View File

@ -0,0 +1,15 @@
package types
import (
"github.com/MrMelon54/pronouns"
"github.com/stretchr/testify/assert"
"testing"
)
func TestUserPronoun_MarshalJSON(t *testing.T) {
assert.Equal(t, "\"they/them\"", encode(UserPronoun{pronouns.TheyThem}))
assert.Equal(t, "\"he/him\"", encode(UserPronoun{pronouns.HeHim}))
assert.Equal(t, "\"she/her\"", encode(UserPronoun{pronouns.SheHer}))
assert.Equal(t, "\"it/its\"", encode(UserPronoun{pronouns.ItIts}))
assert.Equal(t, "\"one/one's\"", encode(UserPronoun{pronouns.OneOnes}))
}

View File

@ -0,0 +1,27 @@
package types
import "fmt"
type UserRole int64
const (
RoleMember UserRole = iota
RoleAdmin
RoleToDelete
)
func (r UserRole) String() string {
switch r {
case RoleMember:
return "Member"
case RoleAdmin:
return "Admin"
case RoleToDelete:
return "ToDelete"
}
return fmt.Sprintf("UserRole{ %d }", r)
}
func (r UserRole) IsValid() bool {
return r == RoleMember || r == RoleAdmin
}

View File

@ -0,0 +1,45 @@
package types
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"time"
)
var (
_ sql.Scanner = &UserZone{}
_ driver.Valuer = &UserZone{}
_ json.Marshaler = &UserZone{}
_ json.Unmarshaler = &UserZone{}
)
type UserZone struct{ *time.Location }
func (l *UserZone) Scan(src any) error {
s, ok := src.(string)
if !ok {
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
}
loc, err := time.LoadLocation(s)
if err != nil {
return err
}
l.Location = loc
return nil
}
func (l UserZone) Value() (driver.Value, error) {
return l.Location.String(), nil
}
func (l UserZone) MarshalJSON() ([]byte, error) {
return json.Marshal(l.Location.String())
}
func (l *UserZone) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}

View File

@ -0,0 +1,14 @@
package types
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestUserZone_MarshalJSON(t *testing.T) {
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, "\"Europe/London\"", encode(UserZone{location}))
assert.Equal(t, "\"UTC\"", encode(UserZone{time.UTC}))
}

View File

@ -0,0 +1,11 @@
package types
import "encoding/json"
func encode(data any) string {
j, err := json.Marshal(data)
if err != nil {
panic(err)
}
return string(j)
}

View File

@ -8,6 +8,10 @@ package database
import (
"context"
"database/sql"
"time"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/password"
)
const deleteTwoFactor = `-- name: DeleteTwoFactor :exec
@ -40,44 +44,20 @@ func (q *Queries) GetTwoFactor(ctx context.Context, subject string) (GetTwoFacto
}
const getUser = `-- name: GetUser :one
SELECT name,
username,
picture,
website,
email,
email_verified,
pronouns,
birthdate,
zoneinfo,
locale,
updated_at,
active
SELECT subject, name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, registered, active
FROM users
WHERE subject = ?
LIMIT 1
`
type GetUserRow struct {
Name string `json:"name"`
Username string `json:"username"`
Picture interface{} `json:"picture"`
Website interface{} `json:"website"`
Email string `json:"email"`
EmailVerified int64 `json:"email_verified"`
Pronouns interface{} `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo interface{} `json:"zoneinfo"`
Locale interface{} `json:"locale"`
UpdatedAt sql.NullTime `json:"updated_at"`
Active sql.NullInt64 `json:"active"`
}
func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, error) {
func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
row := q.db.QueryRowContext(ctx, getUser, subject)
var i GetUserRow
var i User
err := row.Scan(
&i.Subject,
&i.Name,
&i.Username,
&i.Password,
&i.Picture,
&i.Website,
&i.Email,
@ -86,12 +66,51 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, erro
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.Role,
&i.UpdatedAt,
&i.Registered,
&i.Active,
)
return i, err
}
const getUserDisplayName = `-- name: GetUserDisplayName :one
SELECT name
FROM users
WHERE subject = ?
`
func (q *Queries) GetUserDisplayName(ctx context.Context, subject string) (string, error) {
row := q.db.QueryRowContext(ctx, getUserDisplayName, subject)
var name string
err := row.Scan(&name)
return name, err
}
const getUserRole = `-- name: GetUserRole :one
SELECT role
FROM users
WHERE subject = ?
`
func (q *Queries) GetUserRole(ctx context.Context, subject string) (types.UserRole, error) {
row := q.db.QueryRowContext(ctx, getUserRole, subject)
var role types.UserRole
err := row.Scan(&role)
return role, err
}
const hasTwoFactor = `-- name: HasTwoFactor :one
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN)
`
func (q *Queries) HasTwoFactor(ctx context.Context, subject string) (bool, error) {
row := q.db.QueryRowContext(ctx, hasTwoFactor, subject)
var column_1 bool
err := row.Scan(&column_1)
return column_1, err
}
const hasUser = `-- name: HasUser :one
SELECT cast(count(subject) AS BOOLEAN) AS hasUser
FROM users
@ -118,15 +137,15 @@ WHERE subject = ?
`
type ModifyUserParams struct {
Name string `json:"name"`
Picture interface{} `json:"picture"`
Website interface{} `json:"website"`
Pronouns interface{} `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo interface{} `json:"zoneinfo"`
Locale interface{} `json:"locale"`
UpdatedAt sql.NullTime `json:"updated_at"`
Subject string `json:"subject"`
Name string `json:"name"`
Picture string `json:"picture"`
Website string `json:"website"`
Pronouns types.UserPronoun `json:"pronouns"`
Birthdate sql.NullTime `json:"birthdate"`
Zoneinfo types.UserZone `json:"zoneinfo"`
Locale types.UserLocale `json:"locale"`
UpdatedAt time.Time `json:"updated_at"`
Subject string `json:"subject"`
}
func (q *Queries) ModifyUser(ctx context.Context, arg ModifyUserParams) (int64, error) {
@ -171,15 +190,15 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
type addUserParams struct {
Subject string `json:"subject"`
Name string `json:"name"`
Username string `json:"username"`
Password string `json:"password"`
Email string `json:"email"`
EmailVerified int64 `json:"email_verified"`
Role int64 `json:"role"`
UpdatedAt sql.NullTime `json:"updated_at"`
Active sql.NullInt64 `json:"active"`
Subject string `json:"subject"`
Name string `json:"name"`
Username string `json:"username"`
Password password.HashString `json:"password"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Role types.UserRole `json:"role"`
UpdatedAt time.Time `json:"updated_at"`
Active bool `json:"active"`
}
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
@ -206,10 +225,10 @@ WHERE subject = ?
`
type changeUserPasswordParams struct {
Password string `json:"password"`
UpdatedAt sql.NullTime `json:"updated_at"`
Subject string `json:"subject"`
Password_2 string `json:"password_2"`
Password password.HashString `json:"password"`
UpdatedAt time.Time `json:"updated_at"`
Subject string `json:"subject"`
Password_2 password.HashString `json:"password_2"`
}
func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) {
@ -233,11 +252,11 @@ LIMIT 1
`
type checkLoginRow struct {
Subject string `json:"subject"`
Password string `json:"password"`
Column3 int64 `json:"column_3"`
Email string `json:"email"`
EmailVerified int64 `json:"email_verified"`
Subject string `json:"subject"`
Password password.HashString `json:"password"`
Column3 int64 `json:"column_3"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) {
@ -259,9 +278,9 @@ FROM users
WHERE subject = ?
`
func (q *Queries) getUserPassword(ctx context.Context, subject string) (string, error) {
func (q *Queries) getUserPassword(ctx context.Context, subject string) (password.HashString, error) {
row := q.db.QueryRowContext(ctx, getUserPassword, subject)
var password string
var password password.HashString
err := row.Scan(&password)
return password, err
}

View File

@ -4,6 +4,7 @@ import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/tulip/database"
"github.com/1f349/tulip/database/types"
"github.com/julienschmidt/httprouter"
"net/http"
"net/url"
@ -30,14 +31,14 @@ func (u UserAuth) IsGuest() bool {
func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle {
return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
var role database.UserRole
if h.DbTx(rw, func(tx *database.Tx) (err error) {
role, err = tx.GetUserRole(auth.ID)
var role types.UserRole
if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(req.Context(), auth.ID)
return
}) {
return
}
if role != database.RoleAdmin {
if role != types.RoleAdmin {
http.Error(rw, "403 Forbidden", http.StatusForbidden)
return
}

View File

@ -9,25 +9,13 @@ import (
// DbTx wraps a database transaction with http error messages and a simple action
// function. If the action function returns an error the transaction will be
// rolled back. If there is no error then the transaction is committed.
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool {
tx, err := h.db.Begin()
if err != nil {
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)
return true
}
defer tx.Rollback()
err = action(tx)
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool {
err := action(h.db)
if err != nil {
http.Error(rw, "Database error", http.StatusInternalServerError)
log.Println("Database action error:", err)
return true
}
err = tx.Commit()
if err != nil {
http.Error(rw, "Database error", http.StatusInternalServerError)
log.Println("Database commit error:", err)
}
return false
}

View File

@ -11,12 +11,12 @@ import (
"time"
)
func (h *HttpServer) EditGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) {
var user *database.User
func (h *HttpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
var user database.User
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
var err error
user, err = tx.GetUser(auth.ID)
user, err = tx.GetUser(req.Context(), auth.ID)
if err != nil {
return fmt.Errorf("failed to read user data: %w", err)
}
@ -64,8 +64,19 @@ func (h *HttpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httpr
_, _ = fmt.Fprintln(rw, "</body>\n</html>")
return
}
if h.DbTx(rw, func(tx *database.Tx) error {
if err := tx.ModifyUser(auth.ID, &patch); err != nil {
m := database.ModifyUserParams{
Name: patch.Name,
Picture: patch.Picture,
Website: patch.Website,
Pronouns: patch.Pronouns,
Birthdate: patch.Birthdate,
Zoneinfo: patch.ZoneInfo,
Locale: patch.Locale,
UpdatedAt: time.Now(),
Subject: auth.ID,
}
if h.DbTx(rw, func(tx *database.Queries) error {
if _, err := tx.ModifyUser(req.Context(), m); err != nil {
return fmt.Errorf("failed to modify user info: %w", err)
}
return nil

View File

@ -3,6 +3,7 @@ package server
import (
"fmt"
"github.com/1f349/tulip/database"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/pages"
"github.com/google/uuid"
"github.com/julienschmidt/httprouter"
@ -29,18 +30,19 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
return
}
var userWithName *database.User
var userWithName string
var userRole types.UserRole
var hasTwoFactor bool
if h.DbTx(rw, func(tx *database.Tx) (err error) {
userWithName, err = tx.GetUserDisplayName(auth.ID)
if h.DbTx(rw, func(tx *database.Queries) (err error) {
userWithName, err = tx.GetUserDisplayName(req.Context(), auth.ID)
if err != nil {
return fmt.Errorf("failed to get user display name: %w", err)
}
hasTwoFactor, err = tx.HasTwoFactor(auth.ID)
hasTwoFactor, err = tx.HasTwoFactor(req.Context(), auth.ID)
if err != nil {
return fmt.Errorf("failed to get user two factor state: %w", err)
}
userWithName.Role, err = tx.GetUserRole(auth.ID)
userRole, err = tx.GetUserRole(req.Context(), auth.ID)
if err != nil {
return fmt.Errorf("failed to get user role: %w", err)
}
@ -54,6 +56,6 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
"User": userWithName,
"Nonce": lNonce,
"OtpEnabled": hasTwoFactor,
"IsAdmin": userWithName.Role,
"IsAdmin": userRole == types.RoleAdmin,
})
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"github.com/1f349/mjwt"
"github.com/1f349/tulip/database"
"github.com/go-oauth2/oauth2/v4"
@ -9,7 +10,7 @@ import (
"strings"
)
func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) {
func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) {
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
scope := ti.GetScope()
if containsScope(scope, "openid") {
@ -30,18 +31,13 @@ type IdTokenClaims struct{}
func (a IdTokenClaims) Valid() error { return nil }
func (a IdTokenClaims) Type() string { return "access-token" }
func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) {
tx, err := us.Begin()
func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) {
user, err := us.GetUser(context.Background(), ti.GetUserID())
if err != nil {
return "", err
}
user, err := tx.GetUser(ti.GetUserID())
if err != nil {
return "", err
}
tx.Rollback()
token, err = key.GenerateJwt(user.Sub, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
token, err = key.GenerateJwt(user.Subject, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
return
}

View File

@ -62,7 +62,7 @@ func (h *HttpServer) LoginPost(rw http.ResponseWriter, req *http.Request, _ http
var loginMismatch byte
var hasOtp bool
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw)
if err != nil {
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
@ -176,7 +176,7 @@ func (h *HttpServer) LoginResetPasswordPost(rw http.ResponseWriter, req *http.Re
}
var emailExists bool
if h.DbTx(rw, func(tx *database.Tx) (err error) {
if h.DbTx(rw, func(tx *database.Queries) (err error) {
emailExists, err = tx.UserEmailExists(email)
return err
}) {

View File

@ -18,7 +18,7 @@ func (h *HttpServer) MailVerify(rw http.ResponseWriter, _ *http.Request, params
http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
return
}
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.VerifyUserEmail(userSub)
}) {
return
@ -75,7 +75,7 @@ func (h *HttpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request,
h.mailLinkCache.Delete(k)
// reset password database call
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.UserResetPassword(userSub, pw)
}) {
return
@ -94,12 +94,12 @@ func (h *HttpServer) MailDelete(rw http.ResponseWriter, _ *http.Request, params
return
}
var userInfo *database.User
if h.DbTx(rw, func(tx *database.Tx) (err error) {
if h.DbTx(rw, func(tx *database.Queries) (err error) {
userInfo, err = tx.GetUser(userSub)
if err != nil {
return
}
return tx.UpdateUser(userSub, database.RoleToDelete, false)
return tx.UpdateUser(userSub, types.RoleToDelete, false)
}) {
return
}

View File

@ -22,14 +22,14 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
}
}
var role database.UserRole
var appList []database.ClientInfoDbOutput
if h.DbTx(rw, func(tx *database.Tx) (err error) {
var role types.UserRole
var appList []database.ClientStore
if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID)
if err != nil {
return
}
appList, err = tx.GetAppList(auth.ID, role == database.RoleAdmin, offset)
appList, err = tx.GetAppList(auth.ID, role == types.RoleAdmin, offset)
return
}) {
return
@ -39,7 +39,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
"ServiceName": h.conf.ServiceName,
"Apps": appList,
"Offset": offset,
"IsAdmin": role == database.RoleAdmin,
"IsAdmin": role == types.RoleAdmin,
"NewAppName": q.Get("NewAppName"),
"NewAppSecret": q.Get("NewAppSecret"),
}
@ -76,14 +76,14 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
active := req.Form.Has("active")
if sso {
var role database.UserRole
if h.DbTx(rw, func(tx *database.Tx) (err error) {
var role types.UserRole
if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID)
return
}) {
return
}
if role != database.RoleAdmin {
if role != types.RoleAdmin {
http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest)
return
}
@ -91,13 +91,13 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
switch action {
case "create":
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.InsertClientApp(name, domain, public, sso, active, auth.ID)
}) {
return
}
case "edit":
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active)
}) {
return
@ -105,7 +105,7 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
case "secret":
var info oauth2.ClientInfo
var secret string
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
sub := req.Form.Get("subject")
info, err = tx.GetClientInfo(sub)
if err != nil {

View File

@ -3,6 +3,7 @@ package server
import (
"errors"
"github.com/1f349/tulip/database"
"github.com/1f349/tulip/database/types"
"github.com/1f349/tulip/pages"
"github.com/emersion/go-message/mail"
"github.com/google/uuid"
@ -27,9 +28,9 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
}
}
var role database.UserRole
var role types.UserRole
var userList []database.User
if h.DbTx(rw, func(tx *database.Tx) (err error) {
if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID)
if err != nil {
return
@ -39,7 +40,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
}) {
return
}
if role != database.RoleAdmin {
if role != types.RoleAdmin {
http.Error(rw, "403 Forbidden", http.StatusForbidden)
return
}
@ -76,14 +77,14 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
return
}
var role database.UserRole
if h.DbTx(rw, func(tx *database.Tx) (err error) {
var role types.UserRole
if h.DbTx(rw, func(tx *database.Queries) (err error) {
role, err = tx.GetUserRole(auth.ID)
return
}) {
return
}
if role != database.RoleAdmin {
if role != types.RoleAdmin {
http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest)
return
}
@ -116,7 +117,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
addrDomain := address.Address[n+1:]
var userSub uuid.UUID
if h.DbTx(rw, func(tx *database.Tx) (err error) {
if h.DbTx(rw, func(tx *database.Queries) (err error) {
userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active)
return err
}) {
@ -136,7 +137,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
return
}
case "edit":
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
sub := req.Form.Get("subject")
return tx.UpdateUser(sub, newRole, active)
}) {
@ -151,12 +152,12 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
}
func parseRoleValue(role string) (database.UserRole, error) {
func parseRoleValue(role string) (types.UserRole, error) {
switch role {
case "member":
return database.RoleMember, nil
return types.RoleMember, nil
case "admin":
return database.RoleAdmin, nil
return types.RoleAdmin, nil
}
return 0, errors.New("invalid role value")
}

View File

@ -79,7 +79,7 @@ func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request
var user *database.User
var hasOtp bool
if h.DbTx(rw, func(tx *database.Tx) (err error) {
if h.DbTx(rw, func(tx *database.Queries) (err error) {
user, err = tx.GetUserDisplayName(auth.ID)
if err != nil {
return

View File

@ -47,7 +47,7 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code strin
var hasOtp bool
var secret string
var digits int
if h.DbTx(rw, func(tx *database.Tx) (err error) {
if h.DbTx(rw, func(tx *database.Queries) (err error) {
hasOtp, err = tx.HasTwoFactor(sub)
if err != nil {
return
@ -86,7 +86,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
return
}
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.SetTwoFactor(auth.ID, "", 0)
}) {
return
@ -118,7 +118,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
if secret == "" {
// get user email
var email string
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
var err error
email, err = tx.GetUserEmail(auth.ID)
return err
@ -167,7 +167,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
return
}
if h.DbTx(rw, func(tx *database.Tx) error {
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.SetTwoFactor(auth.ID, secret, digits)
}) {
return

View File

@ -31,7 +31,7 @@ type HttpServer struct {
r *httprouter.Router
oauthSrv *server.Server
oauthMgr *manage.Manager
db *database.DB
db *database.Queries
conf Conf
signingKey mjwt.Signer
@ -50,7 +50,7 @@ type mailLinkKey struct {
data string
}
func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server {
func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server {
r := httprouter.New()
// remove last slash from baseUrl
@ -191,10 +191,10 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
return
}
var userData *database.User
var userData database.GetUserRow
if hs.DbTx(rw, func(tx *database.Tx) (err error) {
userData, err = tx.GetUser(userId)
if hs.DbTx(rw, func(tx *database.Queries) (err error) {
userData, err = tx.GetUser(req.Context(), userId)
return err
}) {
return

View File

@ -9,7 +9,13 @@ sql:
out: "database"
emit_json_tags: true
overrides:
- column: "routes.flags"
go_type: "github.com/1f349/violet/target.Flags"
- column: "redirects.flags"
go_type: "github.com/1f349/violet/target.Flags"
- column: "users.password"
go_type: "github.com/1f349/tulip/password.HashString"
- column: "users.role"
go_type: "github.com/1f349/tulip/database/types.UserRole"
- column: "users.pronouns"
go_type: "github.com/1f349/tulip/database/types.UserPronoun"
- column: "users.zoneinfo"
go_type: "github.com/1f349/tulip/database/types.UserZone"
- column: "users.locale"
go_type: "github.com/1f349/tulip/database/types.UserLocale"