From 3a7eef803e6b1a755505fc40dcd82ab158a0d938 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Fri, 17 May 2024 21:40:31 +0100 Subject: [PATCH] Start adding sqlc and migration support --- client-store/client-store.go | 12 +- cmd/lavender/serve.go | 4 +- database/clientstore.go | 22 +++ database/db-types.go | 43 ----- database/db.go | 50 ++--- database/db_test.go | 5 - database/init.sql | 27 --- database/manage-oauth.sql.go | 178 ++++++++++++++++++ database/manage-users.sql.go | 90 +++++++++ .../migrations/20240517171813_init.down.sql | 2 + .../migrations/20240517171813_init.up.sql | 27 +++ database/models.go | 34 ++++ database/queries/manage-oauth.sql | 33 ++++ database/queries/manage-users.sql | 18 ++ database/queries/users.sql | 43 +++++ database/tx.go | 160 ---------------- database/users.sql.go | 172 +++++++++++++++++ go.mod | 5 +- go.sum | 12 +- initdb.go | 38 ++++ server/auth.go | 4 +- server/db.go | 29 +-- server/home.go | 6 +- server/id_token.go | 12 +- server/jwt.go | 11 +- server/login.go | 30 ++- server/manage-apps.go | 79 +++++--- server/manage-users.go | 20 +- server/server.go | 14 +- sqlc.yaml | 10 + 30 files changed, 819 insertions(+), 371 deletions(-) create mode 100644 database/clientstore.go delete mode 100644 database/db-types.go delete mode 100644 database/db_test.go delete mode 100644 database/init.sql create mode 100644 database/manage-oauth.sql.go create mode 100644 database/manage-users.sql.go create mode 100644 database/migrations/20240517171813_init.down.sql create mode 100644 database/migrations/20240517171813_init.up.sql create mode 100644 database/models.go create mode 100644 database/queries/manage-oauth.sql create mode 100644 database/queries/manage-users.sql create mode 100644 database/queries/users.sql delete mode 100644 database/tx.go create mode 100644 database/users.sql.go create mode 100644 initdb.go create mode 100644 sqlc.yaml diff --git a/client-store/client-store.go b/client-store/client-store.go index 97da29a..ad9b477 100644 --- a/client-store/client-store.go +++ b/client-store/client-store.go @@ -7,20 +7,16 @@ import ( ) type ClientStore struct { - db *database.DB + db *database.Queries } var _ oauth2.ClientStore = &ClientStore{} -func New(db *database.DB) *ClientStore { +func New(db *database.Queries) *ClientStore { return &ClientStore{db: db} } func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { - tx, err := c.db.BeginCtx(ctx) - if err != nil { - return nil, err - } - defer tx.Rollback() - return tx.GetClientInfo(id) + a, err := c.db.GetClientInfo(ctx, id) + return &a, err } diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index d565f7e..c637fcf 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -9,7 +9,7 @@ import ( "encoding/pem" "errors" "flag" - "github.com/1f349/lavender/database" + "github.com/1f349/lavender" "github.com/1f349/lavender/logger" "github.com/1f349/lavender/pages" "github.com/1f349/lavender/server" @@ -75,7 +75,7 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) } saveMjwtPubKey(signingKey, wd) - db, err := database.Open(filepath.Join(wd, "lavender.db.sqlite")) + db, err := lavender.InitDB(filepath.Join(wd, "lavender.db.sqlite")) if err != nil { logger.Logger.Fatal("Failed to open database:", err) } diff --git a/database/clientstore.go b/database/clientstore.go new file mode 100644 index 0000000..1e1c1fc --- /dev/null +++ b/database/clientstore.go @@ -0,0 +1,22 @@ +package database + +import "github.com/go-oauth2/oauth2/v4" + +var _ oauth2.ClientInfo = &ClientStore{} + +func (c *ClientStore) GetID() string { return c.Subject } +func (c *ClientStore) GetSecret() string { return c.Secret } +func (c *ClientStore) GetDomain() string { return c.Domain } +func (c *ClientStore) IsPublic() bool { return c.Public } +func (c *ClientStore) GetUserID() string { return c.Owner } + +// GetName is an extra field for the oauth handler to display the application +// name +func (c *ClientStore) GetName() string { return c.Name } + +// IsSSO is an extra field for the oauth handler to skip the user input stage +// this is for trusted applications to get permissions without asking the user +func (c *ClientStore) IsSSO() bool { return c.Sso } + +// IsActive is an extra field for the app manager to get the active state +func (c *ClientStore) IsActive() bool { return c.Active } diff --git a/database/db-types.go b/database/db-types.go deleted file mode 100644 index 9aa3145..0000000 --- a/database/db-types.go +++ /dev/null @@ -1,43 +0,0 @@ -package database - -import ( - "github.com/go-oauth2/oauth2/v4" - "time" -) - -type User struct { - Subject string `json:"subject"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - Roles string `json:"roles"` - UserInfo string `json:"userinfo"` - UpdatedAt time.Time `json:"updated_at"` - Active bool `json:"active"` -} - -type ClientInfoDbOutput struct { - Subject, Name, Secret, Domain, Owner, Perms string - Public, Sso, Active bool -} - -var _ oauth2.ClientInfo = &ClientInfoDbOutput{} - -func (c *ClientInfoDbOutput) GetID() string { return c.Subject } -func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret } -func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain } -func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public } -func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner } - -// GetName is an extra field for the oauth handler to display the application -// name -func (c *ClientInfoDbOutput) GetName() string { return c.Name } - -// IsSSO is an extra field for the oauth handler to skip the user input stage -// this is for trusted applications to get permissions without asking the user -func (c *ClientInfoDbOutput) IsSSO() bool { return c.Sso } - -// IsActive is an extra field for the app manager to get the active state -func (c *ClientInfoDbOutput) IsActive() bool { return c.Active } - -// UsePerms is an extra field for the userinfo handler to return user permissions matching the requested values -func (c *ClientInfoDbOutput) UsePerms() string { return c.Perms } diff --git a/database/db.go b/database/db.go index 0e38862..61f5bf4 100644 --- a/database/db.go +++ b/database/db.go @@ -1,41 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + package database import ( "context" "database/sql" - _ "embed" ) -//go:embed init.sql -var initSql string +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} -type DB struct{ db *sql.DB } +func New(db DBTX) *Queries { + return &Queries{db: db} +} -func Open(p string) (*DB, error) { - db, err := sql.Open("sqlite3", p) - if err != nil { - return nil, err +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, } - _, err = db.Exec(initSql) - return &DB{db: db}, err -} - -func (d *DB) Begin() (*Tx, error) { - begin, err := d.db.Begin() - if err != nil { - return nil, err - } - return &Tx{begin}, err -} - -func (d *DB) BeginCtx(ctx context.Context) (*Tx, error) { - begin, err := d.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - return &Tx{begin}, err -} - -func (d *DB) Close() error { - return d.db.Close() } diff --git a/database/db_test.go b/database/db_test.go deleted file mode 100644 index 5d73aab..0000000 --- a/database/db_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package database - -import ( - _ "github.com/mattn/go-sqlite3" -) diff --git a/database/init.sql b/database/init.sql deleted file mode 100644 index 0009f1d..0000000 --- a/database/init.sql +++ /dev/null @@ -1,27 +0,0 @@ -CREATE TABLE IF NOT EXISTS users -( - subject TEXT PRIMARY KEY UNIQUE NOT NULL, - email TEXT NOT NULL, - email_verified INTEGER DEFAULT 0 NOT NULL, - roles TEXT NOT NULL, - userinfo TEXT, - access_token TEXT, - refresh_token TEXT, - expiry DATETIME, - updated_at DATETIME, - active INTEGER DEFAULT 1 -); - -CREATE TABLE IF NOT EXISTS client_store -( - subject TEXT PRIMARY KEY UNIQUE NOT NULL, - name TEXT NOT NULL, - secret TEXT UNIQUE NOT NULL, - domain TEXT NOT NULL, - owner TEXT NOT NULL, - perms TEXT NOT NULL, - public INTEGER, - sso INTEGER, - active INTEGER DEFAULT 1, - FOREIGN KEY (owner) REFERENCES users (subject) -); diff --git a/database/manage-oauth.sql.go b/database/manage-oauth.sql.go new file mode 100644 index 0000000..5b25999 --- /dev/null +++ b/database/manage-oauth.sql.go @@ -0,0 +1,178 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: manage-oauth.sql + +package database + +import ( + "context" +) + +const getAppList = `-- name: GetAppList :many +SELECT subject, name, domain, owner, public, sso, active +FROM client_store +WHERE owner = ? + OR ? = 1 +LIMIT 25 OFFSET ? +` + +type GetAppListParams struct { + Owner string `json:"owner"` + Column2 interface{} `json:"column_2"` + Offset int64 `json:"offset"` +} + +type GetAppListRow struct { + Subject string `json:"subject"` + Name string `json:"name"` + Domain string `json:"domain"` + Owner string `json:"owner"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` +} + +func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) { + rows, err := q.db.QueryContext(ctx, getAppList, arg.Owner, arg.Column2, arg.Offset) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAppListRow + for rows.Next() { + var i GetAppListRow + if err := rows.Scan( + &i.Subject, + &i.Name, + &i.Domain, + &i.Owner, + &i.Public, + &i.Sso, + &i.Active, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getClientInfo = `-- name: GetClientInfo :one +SELECT subject, name, secret, domain, owner, perms, public, sso, active +FROM client_store +WHERE subject = ? +LIMIT 1 +` + +func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) { + row := q.db.QueryRowContext(ctx, getClientInfo, subject) + var i ClientStore + err := row.Scan( + &i.Subject, + &i.Name, + &i.Secret, + &i.Domain, + &i.Owner, + &i.Perms, + &i.Public, + &i.Sso, + &i.Active, + ) + return i, err +} + +const insertClientApp = `-- name: InsertClientApp :exec +INSERT INTO client_store (subject, name, secret, domain, owner, perms, public, sso, active) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) +` + +type InsertClientAppParams struct { + Subject string `json:"subject"` + Name string `json:"name"` + Secret string `json:"secret"` + Domain string `json:"domain"` + Owner string `json:"owner"` + Perms string `json:"perms"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` +} + +func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error { + _, err := q.db.ExecContext(ctx, insertClientApp, + arg.Subject, + arg.Name, + arg.Secret, + arg.Domain, + arg.Owner, + arg.Perms, + arg.Public, + arg.Sso, + arg.Active, + ) + return err +} + +const resetClientAppSecret = `-- name: ResetClientAppSecret :exec +UPDATE client_store +SET secret = ? +WHERE subject = ? + AND owner = ? +` + +type ResetClientAppSecretParams struct { + Secret string `json:"secret"` + Subject string `json:"subject"` + Owner string `json:"owner"` +} + +func (q *Queries) ResetClientAppSecret(ctx context.Context, arg ResetClientAppSecretParams) error { + _, err := q.db.ExecContext(ctx, resetClientAppSecret, arg.Secret, arg.Subject, arg.Owner) + return err +} + +const updateClientApp = `-- name: UpdateClientApp :exec +UPDATE client_store +SET name = ?, + domain = ?, + perms = CASE WHEN CAST(? AS BOOLEAN) = true THEN ? ELSE perms END, + public = ?, + sso = ?, + active = ? +WHERE subject = ? + AND owner = ? +` + +type UpdateClientAppParams struct { + Name string `json:"name"` + Domain string `json:"domain"` + Column3 bool `json:"column_3"` + Perms string `json:"perms"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` + Subject string `json:"subject"` + Owner string `json:"owner"` +} + +func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error { + _, err := q.db.ExecContext(ctx, updateClientApp, + arg.Name, + arg.Domain, + arg.Column3, + arg.Perms, + arg.Public, + arg.Sso, + arg.Active, + arg.Subject, + arg.Owner, + ) + return err +} diff --git a/database/manage-users.sql.go b/database/manage-users.sql.go new file mode 100644 index 0000000..94513d0 --- /dev/null +++ b/database/manage-users.sql.go @@ -0,0 +1,90 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: manage-users.sql + +package database + +import ( + "context" + "time" +) + +const getUserList = `-- name: GetUserList :many +SELECT subject, + email, + email_verified, + roles, + updated_at, + active +FROM users +LIMIT 25 OFFSET ? +` + +type GetUserListRow struct { + Subject string `json:"subject"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Roles string `json:"roles"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` +} + +func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) { + rows, err := q.db.QueryContext(ctx, getUserList, offset) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetUserListRow + for rows.Next() { + var i GetUserListRow + if err := rows.Scan( + &i.Subject, + &i.Email, + &i.EmailVerified, + &i.Roles, + &i.UpdatedAt, + &i.Active, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateUser = `-- name: UpdateUser :exec +UPDATE users +SET active = ?, + roles=? +WHERE subject = ? +` + +type UpdateUserParams struct { + Active bool `json:"active"` + Roles string `json:"roles"` + Subject string `json:"subject"` +} + +func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) error { + _, err := q.db.ExecContext(ctx, updateUser, arg.Active, arg.Roles, arg.Subject) + return err +} + +const userEmailExists = `-- name: UserEmailExists :one +SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists +` + +func (q *Queries) UserEmailExists(ctx context.Context, email string) (bool, error) { + row := q.db.QueryRowContext(ctx, userEmailExists, email) + var email_exists bool + err := row.Scan(&email_exists) + return email_exists, err +} diff --git a/database/migrations/20240517171813_init.down.sql b/database/migrations/20240517171813_init.down.sql new file mode 100644 index 0000000..bdfa645 --- /dev/null +++ b/database/migrations/20240517171813_init.down.sql @@ -0,0 +1,2 @@ +DROP TABLE users; +DROP TABLE client_store; diff --git a/database/migrations/20240517171813_init.up.sql b/database/migrations/20240517171813_init.up.sql new file mode 100644 index 0000000..f3acbef --- /dev/null +++ b/database/migrations/20240517171813_init.up.sql @@ -0,0 +1,27 @@ +CREATE TABLE users +( + subject TEXT PRIMARY KEY UNIQUE NOT NULL, + email TEXT UNIQUE NOT NULL, + email_verified BOOLEAN DEFAULT 0 NOT NULL, + roles TEXT NOT NULL, + userinfo TEXT NOT NULL, + access_token TEXT NOT NULL, + refresh_token TEXT NOT NULL, + expiry DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + active BOOLEAN DEFAULT 1 NOT NULL +); + +CREATE TABLE client_store +( + subject TEXT PRIMARY KEY UNIQUE NOT NULL, + name TEXT NOT NULL, + secret TEXT UNIQUE NOT NULL, + domain TEXT NOT NULL, + owner TEXT NOT NULL, + perms TEXT NOT NULL, + public BOOLEAN NOT NULL, + sso BOOLEAN NOT NULL, + active BOOLEAN DEFAULT 1 NOT NULL, + FOREIGN KEY (owner) REFERENCES users (subject) +); diff --git a/database/models.go b/database/models.go new file mode 100644 index 0000000..cffc13c --- /dev/null +++ b/database/models.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package database + +import ( + "time" +) + +type ClientStore struct { + Subject string `json:"subject"` + Name string `json:"name"` + Secret string `json:"secret"` + Domain string `json:"domain"` + Owner string `json:"owner"` + Perms string `json:"perms"` + Public bool `json:"public"` + Sso bool `json:"sso"` + Active bool `json:"active"` +} + +type User struct { + Subject string `json:"subject"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Roles string `json:"roles"` + Userinfo string `json:"userinfo"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Expiry time.Time `json:"expiry"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` +} diff --git a/database/queries/manage-oauth.sql b/database/queries/manage-oauth.sql new file mode 100644 index 0000000..5a70d75 --- /dev/null +++ b/database/queries/manage-oauth.sql @@ -0,0 +1,33 @@ +-- name: GetClientInfo :one +SELECT * +FROM client_store +WHERE subject = ? +LIMIT 1; + +-- name: GetAppList :many +SELECT subject, name, domain, owner, public, sso, active +FROM client_store +WHERE owner = ? + OR ? = 1 +LIMIT 25 OFFSET ?; + +-- name: InsertClientApp :exec +INSERT INTO client_store (subject, name, secret, domain, owner, perms, public, sso, active) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + +-- name: UpdateClientApp :exec +UPDATE client_store +SET name = ?, + domain = ?, + perms = CASE WHEN CAST(? AS BOOLEAN) = true THEN ? ELSE perms END, + public = ?, + sso = ?, + active = ? +WHERE subject = ? + AND owner = ?; + +-- name: ResetClientAppSecret :exec +UPDATE client_store +SET secret = ? +WHERE subject = ? + AND owner = ?; diff --git a/database/queries/manage-users.sql b/database/queries/manage-users.sql new file mode 100644 index 0000000..344ca4c --- /dev/null +++ b/database/queries/manage-users.sql @@ -0,0 +1,18 @@ +-- name: GetUserList :many +SELECT subject, + email, + email_verified, + roles, + updated_at, + active +FROM users +LIMIT 25 OFFSET ?; + +-- name: UpdateUser :exec +UPDATE users +SET active = ?, + roles=? +WHERE subject = ?; + +-- name: UserEmailExists :one +SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists; diff --git a/database/queries/users.sql b/database/queries/users.sql new file mode 100644 index 0000000..e565f83 --- /dev/null +++ b/database/queries/users.sql @@ -0,0 +1,43 @@ +-- name: HasUser :one +SELECT count(subject) > 0 AS hasUser +FROM users; + +-- name: AddUser :exec +INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active) +VALUES (?, ?, ?, ?, ?, ?, ?); + +-- name: UpdateUserInfo :exec +UPDATE users +SET email = ?, + email_verified = ?, + userinfo = ? +WHERE subject = ?; + +-- name: GetUserRoles :one +SELECT roles +FROM users +WHERE subject = ?; + +-- name: GetUser :one +SELECT * +FROM users +WHERE subject = ? +LIMIT 1; + +-- name: UpdateUserToken :exec +UPDATE users +SET access_token = ?, + refresh_token = ?, + expiry = ? +WHERE subject = ?; + +-- name: GetUserToken :one +SELECT access_token, refresh_token, expiry +FROM users +WHERE subject = ? +LIMIT 1; + +-- name: GetUserEmail :one +SELECT email +FROM users +WHERE subject = ?; diff --git a/database/tx.go b/database/tx.go deleted file mode 100644 index 7190b35..0000000 --- a/database/tx.go +++ /dev/null @@ -1,160 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "github.com/1f349/lavender/password" - "github.com/go-oauth2/oauth2/v4" - "github.com/google/uuid" - "time" -) - -func updatedAt() string { - return time.Now().UTC().Format(time.DateTime) -} - -type Tx struct{ tx *sql.Tx } - -func (t *Tx) Commit() error { - return t.tx.Commit() -} - -func (t *Tx) Rollback() { - _ = t.tx.Rollback() -} - -func (t *Tx) HasUser() error { - var exists bool - row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`) - err := row.Scan(&exists) - if err != nil { - return err - } - if !exists { - return sql.ErrNoRows - } - return nil -} - -func (t *Tx) InsertUser(subject, email string, verifyEmail bool, roles, userinfo string, active bool) error { - _, err := t.tx.Exec(`INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?)`, subject, email, verifyEmail, roles, userinfo, updatedAt(), active) - return err -} - -func (t *Tx) UpdateUserInfo(subject, email string, verified bool, userinfo string) error { - _, err := t.tx.Exec(`UPDATE users SET email = ?, email_verified = ?, userinfo = ? WHERE subject = ?`, email, verified, userinfo, subject) - return err -} - -func (t *Tx) GetUserRoles(sub string) (string, error) { - var r string - row := t.tx.QueryRow(`SELECT roles FROM users WHERE subject = ? LIMIT 1`, sub) - err := row.Scan(&r) - return r, err -} - -func (t *Tx) GetUser(sub string) (*User, error) { - var u User - row := t.tx.QueryRow(`SELECT email, email_verified, roles, userinfo, updated_at, active FROM users WHERE subject = ?`, sub) - err := row.Scan(&u.Email, &u.EmailVerified, &u.Roles, &u.UserInfo, &u.UpdatedAt, &u.Active) - u.Subject = sub - return &u, err -} - -func (t *Tx) GetUserEmail(sub string) (string, error) { - var email string - row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub) - err := row.Scan(&email) - return email, err -} - -func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) { - var u ClientInfoDbOutput - row := t.tx.QueryRow(`SELECT secret, name, domain, perms, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub) - err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Perms, &u.Public, &u.Sso, &u.Active) - u.Owner = sub - if !u.Active { - return nil, fmt.Errorf("client is not active") - } - return &u, err -} - -func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) { - var u []ClientInfoDbOutput - row, err := t.tx.Query(`SELECT subject, name, domain, owner, perms, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset) - if err != nil { - return nil, err - } - defer row.Close() - for row.Next() { - var a ClientInfoDbOutput - err := row.Scan(&a.Subject, &a.Name, &a.Domain, &a.Owner, &a.Perms, &a.Public, &a.Sso, &a.Active) - if err != nil { - return nil, err - } - u = append(u, a) - } - return u, row.Err() -} - -func (t *Tx) InsertClientApp(name, domain, owner, perms string, public, sso, active bool) error { - u := uuid.New() - secret, err := password.GenerateApiSecret(70) - if err != nil { - return err - } - _, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, perms, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, perms, public, sso, active) - return err -} - -func (t *Tx) UpdateClientApp(subject uuid.UUID, owner, name, domain, perms string, hasPerms, public, sso, active bool) error { - _, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, perms = CASE WHEN ? = true THEN ? ELSE perms END, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, hasPerms, perms, public, sso, active, subject.String(), owner) - return err -} - -func (t *Tx) ResetClientAppSecret(subject uuid.UUID, owner string) (string, error) { - secret, err := password.GenerateApiSecret(70) - if err != nil { - return "", err - } - _, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject.String(), owner) - return secret, err -} - -func (t *Tx) GetUserList(offset int) ([]User, error) { - var u []User - row, err := t.tx.Query(`SELECT subject, email, email_verified, roles, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset) - if err != nil { - return nil, err - } - for row.Next() { - var a User - err := row.Scan(&a.Subject, &a.Email, &a.EmailVerified, &a.Roles, &a.UpdatedAt, &a.Active) - if err != nil { - return nil, err - } - u = append(u, a) - } - return u, row.Err() -} - -func (t *Tx) UpdateUser(subject, roles string, active bool) error { - _, err := t.tx.Exec(`UPDATE users SET active = ?, roles = ? WHERE subject = ?`, active, roles, subject) - return err -} - -func (t *Tx) UpdateUserToken(subject, accessToken, refreshToken string, expiry time.Time) error { - _, err := t.tx.Exec(`UPDATE users SET access_token = ?, refresh_token = ?, expiry = ? WHERE subject = ?`, accessToken, refreshToken, expiry, subject) - return err -} - -func (t *Tx) GetUserToken(subject string, accessToken, refreshToken *string, expiry *time.Time) error { - row := t.tx.QueryRow(`SELECT access_token, refresh_token, expiry FROM users WHERE subject = ? LIMIT 1`, subject) - return row.Scan(accessToken, refreshToken, expiry) -} - -func (t *Tx) UserEmailExists(email string) (exists bool, err error) { - row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email) - err = row.Scan(&exists) - return -} diff --git a/database/users.sql.go b/database/users.sql.go new file mode 100644 index 0000000..10f0e10 --- /dev/null +++ b/database/users.sql.go @@ -0,0 +1,172 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: users.sql + +package database + +import ( + "context" + "time" +) + +const addUser = `-- name: AddUser :exec +INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active) +VALUES (?, ?, ?, ?, ?, ?, ?) +` + +type AddUserParams struct { + Subject string `json:"subject"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Roles string `json:"roles"` + Userinfo string `json:"userinfo"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` +} + +func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) error { + _, err := q.db.ExecContext(ctx, addUser, + arg.Subject, + arg.Email, + arg.EmailVerified, + arg.Roles, + arg.Userinfo, + arg.UpdatedAt, + arg.Active, + ) + return err +} + +const getUser = `-- name: GetUser :one +SELECT subject, email, email_verified, roles, userinfo, access_token, refresh_token, expiry, updated_at, active +FROM users +WHERE subject = ? +LIMIT 1 +` + +func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) { + row := q.db.QueryRowContext(ctx, getUser, subject) + var i User + err := row.Scan( + &i.Subject, + &i.Email, + &i.EmailVerified, + &i.Roles, + &i.Userinfo, + &i.AccessToken, + &i.RefreshToken, + &i.Expiry, + &i.UpdatedAt, + &i.Active, + ) + return i, err +} + +const getUserEmail = `-- name: GetUserEmail :one +SELECT email +FROM users +WHERE subject = ? +` + +func (q *Queries) GetUserEmail(ctx context.Context, subject string) (string, error) { + row := q.db.QueryRowContext(ctx, getUserEmail, subject) + var email string + err := row.Scan(&email) + return email, err +} + +const getUserRoles = `-- name: GetUserRoles :one +SELECT roles +FROM users +WHERE subject = ? +` + +func (q *Queries) GetUserRoles(ctx context.Context, subject string) (string, error) { + row := q.db.QueryRowContext(ctx, getUserRoles, subject) + var roles string + err := row.Scan(&roles) + return roles, err +} + +const getUserToken = `-- name: GetUserToken :one +SELECT access_token, refresh_token, expiry +FROM users +WHERE subject = ? +LIMIT 1 +` + +type GetUserTokenRow struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Expiry time.Time `json:"expiry"` +} + +func (q *Queries) GetUserToken(ctx context.Context, subject string) (GetUserTokenRow, error) { + row := q.db.QueryRowContext(ctx, getUserToken, subject) + var i GetUserTokenRow + err := row.Scan(&i.AccessToken, &i.RefreshToken, &i.Expiry) + return i, err +} + +const hasUser = `-- name: HasUser :one +SELECT count(subject) > 0 AS hasUser +FROM users +` + +func (q *Queries) HasUser(ctx context.Context) (bool, error) { + row := q.db.QueryRowContext(ctx, hasUser) + var hasuser bool + err := row.Scan(&hasuser) + return hasuser, err +} + +const updateUserInfo = `-- name: UpdateUserInfo :exec +UPDATE users +SET email = ?, + email_verified = ?, + userinfo = ? +WHERE subject = ? +` + +type UpdateUserInfoParams struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Userinfo string `json:"userinfo"` + Subject string `json:"subject"` +} + +func (q *Queries) UpdateUserInfo(ctx context.Context, arg UpdateUserInfoParams) error { + _, err := q.db.ExecContext(ctx, updateUserInfo, + arg.Email, + arg.EmailVerified, + arg.Userinfo, + arg.Subject, + ) + return err +} + +const updateUserToken = `-- name: UpdateUserToken :exec +UPDATE users +SET access_token = ?, + refresh_token = ?, + expiry = ? +WHERE subject = ? +` + +type UpdateUserTokenParams struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Expiry time.Time `json:"expiry"` + Subject string `json:"subject"` +} + +func (q *Queries) UpdateUserToken(ctx context.Context, arg UpdateUserTokenParams) error { + _, err := q.db.ExecContext(ctx, updateUserToken, + arg.AccessToken, + arg.RefreshToken, + arg.Expiry, + arg.Subject, + ) + return err +} diff --git a/go.mod b/go.mod index e4867a3..9e07ea3 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/charmbracelet/log v0.4.0 github.com/go-oauth2/oauth2/v4 v4.5.2 github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/subcommands v1.2.0 github.com/google/uuid v1.6.0 github.com/julienschmidt/httprouter v1.3.0 @@ -26,7 +27,8 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect - github.com/google/go-querystring v1.1.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/klauspost/compress v1.17.8 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -45,6 +47,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/rtred v0.1.2 // indirect github.com/tidwall/tinyqueue v0.1.1 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/net v0.25.0 // indirect golang.org/x/sys v0.20.0 // indirect diff --git a/go.sum b/go.sum index 0b38578..61bfee3 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4= +github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -50,7 +52,6 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= @@ -66,6 +67,11 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= @@ -83,6 +89,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -179,6 +187,8 @@ github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FB github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= diff --git a/initdb.go b/initdb.go new file mode 100644 index 0000000..f98070b --- /dev/null +++ b/initdb.go @@ -0,0 +1,38 @@ +package lavender + +import ( + "database/sql" + "embed" + "errors" + "github.com/1f349/lavender/database" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" +) + +//go:embed database/migrations/*.sql +var migrations embed.FS + +func InitDB(p string) (*database.Queries, error) { + migDrv, err := iofs.New(migrations, "database/migrations") + if err != nil { + return nil, err + } + dbOpen, err := sql.Open("sqlite3", p) + if err != nil { + return nil, err + } + dbDrv, err := sqlite3.WithInstance(dbOpen, &sqlite3.Config{}) + if err != nil { + return nil, err + } + mig, err := migrate.NewWithInstance("iofs", migDrv, "sqlite3", dbDrv) + if err != nil { + return nil, err + } + err = mig.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + return nil, err + } + return database.New(dbOpen), nil +} diff --git a/server/auth.go b/server/auth.go index 1178d1c..8d40964 100644 --- a/server/auth.go +++ b/server/auth.go @@ -21,8 +21,8 @@ func (u UserAuth) IsGuest() bool { return u.Subject == "" } func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle { return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { var roles string - if h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err = tx.GetUserRoles(auth.Subject) + if h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return }) { return diff --git a/server/db.go b/server/db.go index 8301d60..3ede57e 100644 --- a/server/db.go +++ b/server/db.go @@ -9,35 +9,14 @@ import ( // DbTx wraps a database transaction with http error messages and a simple action // function. If the action function returns an error the transaction will be // rolled back. If there is no error then the transaction is committed. -func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool { - tx, err := h.db.Begin() - if err != nil { - http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError) - return true - } - defer tx.Rollback() - - err = action(tx) +func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool { + err := action(h.db) if err != nil { http.Error(rw, "Database error", http.StatusInternalServerError) - logger.Logger.Warn("Database action error", "er", err) + logger.Logger.Helper() + logger.Logger.Warn("Database action error", "err", err) return true } - err = tx.Commit() - if err != nil { - http.Error(rw, "Database error", http.StatusInternalServerError) - logger.Logger.Warn("Database commit error", "err", err) - } return false } - -func (h *HttpServer) DbTxRaw(action func(tx *database.Tx) error) bool { - return h.DbTx(&fakeRW{}, action) -} - -type fakeRW struct{} - -func (f *fakeRW) Header() http.Header { return http.Header{} } -func (f *fakeRW) Write(b []byte) (int, error) { return len(b), nil } -func (f *fakeRW) WriteHeader(statusCode int) {} diff --git a/server/home.go b/server/home.go index 4e28c61..0142967 100644 --- a/server/home.go +++ b/server/home.go @@ -9,7 +9,7 @@ import ( "time" ) -func (h *HttpServer) Home(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { rw.Header().Set("Content-Type", "text/html") lNonce := uuid.NewString() http.SetCookie(rw, &http.Cookie{ @@ -29,8 +29,8 @@ func (h *HttpServer) Home(rw http.ResponseWriter, _ *http.Request, _ httprouter. } var isAdmin bool - h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err := tx.GetUserRoles(auth.Subject) + h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err := tx.GetUserRoles(req.Context(), auth.Subject) isAdmin = HasRole(roles, "lavender:admin") return err }) diff --git a/server/id_token.go b/server/id_token.go index 8e502dd..1250043 100644 --- a/server/id_token.go +++ b/server/id_token.go @@ -1,6 +1,7 @@ package server import ( + "context" "github.com/1f349/lavender/database" "github.com/1f349/mjwt" "github.com/go-oauth2/oauth2/v4" @@ -9,7 +10,7 @@ import ( "strings" ) -func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) { +func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) { srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) { scope := ti.GetScope() if containsScope(scope, "openid") { @@ -32,13 +33,8 @@ type IdTokenClaims struct { func (a IdTokenClaims) Valid() error { return nil } func (a IdTokenClaims) Type() string { return "id-token" } -func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) { - tx, err := us.Begin() - if err != nil { - return "", err - } - defer tx.Rollback() - user, err := tx.GetUser(ti.GetUserID()) +func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) { + user, err := us.GetUser(context.Background(), ti.GetUserID()) if err != nil { return "", err } diff --git a/server/jwt.go b/server/jwt.go index dbfe47a..253bd6a 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -16,25 +16,20 @@ import ( type JWTAccessGenerate struct { signer mjwt.Signer - db *database.DB + db *database.Queries } -func NewJWTAccessGenerate(signer mjwt.Signer, db *database.DB) *JWTAccessGenerate { +func NewJWTAccessGenerate(signer mjwt.Signer, db *database.Queries) *JWTAccessGenerate { return &JWTAccessGenerate{signer, db} } var _ oauth2.AccessGenerate = &JWTAccessGenerate{} func (j *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (access, refresh string, err error) { - beginCtx, err := j.db.BeginCtx(ctx) + roles, err := j.db.GetUserRoles(ctx, data.UserID) if err != nil { return "", "", err } - roles, err := beginCtx.GetUserRoles(data.UserID) - if err != nil { - return "", "", err - } - beginCtx.Rollback() ps := claims.ParsePermStorage(roles) out := claims.NewPermStorage() diff --git a/server/login.go b/server/login.go index 97a07de..1dc42f7 100644 --- a/server/login.go +++ b/server/login.go @@ -114,20 +114,33 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ return } - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { jBytes, err := json.Marshal(sessionData.UserInfo) if err != nil { return err } - _, err = tx.GetUser(sessionData.Subject) + _, err = tx.GetUser(req.Context(), sessionData.Subject) if errors.Is(err, sql.ErrNoRows) { uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") - return tx.InsertUser(sessionData.Subject, uEmail, uEmailVerified, "", string(jBytes), true) + return tx.AddUser(req.Context(), database.AddUserParams{ + Subject: sessionData.Subject, + Email: uEmail, + EmailVerified: uEmailVerified, + Roles: "", + Userinfo: string(jBytes), + UpdatedAt: time.Now(), + Active: true, + }) } uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") - return tx.UpdateUserInfo(sessionData.Subject, uEmail, uEmailVerified, string(jBytes)) + return tx.UpdateUserInfo(req.Context(), database.UpdateUserInfoParams{ + Email: sessionData.Subject, + EmailVerified: uEmailVerified, + Userinfo: string(jBytes), + Subject: uEmail, + }) }) { return } @@ -135,8 +148,13 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ // only continues if the above tx succeeds auth = sessionData - if h.DbTx(rw, func(tx *database.Tx) error { - return tx.UpdateUserToken(auth.Subject, token.AccessToken, token.RefreshToken, token.Expiry) + if h.DbTx(rw, func(tx *database.Queries) error { + return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + Expiry: token.Expiry, + Subject: auth.Subject, + }) }) { return } diff --git a/server/manage-apps.go b/server/manage-apps.go index eab8807..e96472b 100644 --- a/server/manage-apps.go +++ b/server/manage-apps.go @@ -3,7 +3,7 @@ package server import ( "github.com/1f349/lavender/database" "github.com/1f349/lavender/pages" - "github.com/go-oauth2/oauth2/v4" + "github.com/1f349/lavender/password" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "net/http" @@ -24,13 +24,17 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ } var roles string - var appList []database.ClientInfoDbOutput - if h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err = tx.GetUserRoles(auth.Subject) + var appList []database.GetAppListRow + if h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err = tx.GetUserRoles(req.Context(), auth.Subject) if err != nil { return } - appList, err = tx.GetAppList(auth.Subject, HasRole(roles, "lavender:admin"), offset) + appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{ + Owner: auth.Subject, + Column2: HasRole(roles, "lavender:admin"), + Offset: int64(offset), + }) return }) { return @@ -63,10 +67,10 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ pages.RenderPageTemplate(rw, "manage-apps", m) } -func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { var roles string - if h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err = tx.GetUserRoles(auth.Subject) + if h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return }) { return @@ -100,8 +104,8 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ if sso || hasPerms { var roles string - if h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err = tx.GetUserRoles(auth.Subject) + if h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return }) { return @@ -118,43 +122,64 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ switch action { case "create": - if h.DbTx(rw, func(tx *database.Tx) error { - return tx.InsertClientApp(name, domain, auth.Subject, perms, public, sso, active) + if h.DbTx(rw, func(tx *database.Queries) error { + secret, err := password.GenerateApiSecret(70) + if err != nil { + return err + } + return tx.InsertClientApp(req.Context(), database.InsertClientAppParams{ + Subject: uuid.NewString(), + Name: name, + Secret: secret, + Domain: domain, + Owner: auth.Subject, + Perms: perms, + Public: public, + Sso: sso, + Active: active, + }) }) { return } case "edit": - if h.DbTx(rw, func(tx *database.Tx) error { - sub, err := uuid.Parse(req.Form.Get("subject")) - if err != nil { - return err - } - return tx.UpdateClientApp(sub, auth.Subject, name, domain, perms, hasPerms, public, sso, active) + if h.DbTx(rw, func(tx *database.Queries) error { + return tx.UpdateClientApp(req.Context(), database.UpdateClientAppParams{ + Name: name, + Domain: domain, + Column3: hasPerms, + Public: public, + Sso: sso, + Active: active, + Subject: req.FormValue("subject"), + Owner: auth.Subject, + }) }) { return } case "secret": - var info oauth2.ClientInfo + var info database.ClientStore var secret string - if h.DbTx(rw, func(tx *database.Tx) error { - sub, err := uuid.Parse(req.Form.Get("subject")) + if h.DbTx(rw, func(tx *database.Queries) error { + sub := req.Form.Get("subject") + info, err = tx.GetClientInfo(req.Context(), sub) if err != nil { return err } - info, err = tx.GetClientInfo(sub.String()) + secret, err := password.GenerateApiSecret(70) if err != nil { return err } - secret, err = tx.ResetClientAppSecret(sub, auth.Subject) + err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{ + Secret: secret, + Subject: sub, + Owner: auth.Subject, + }) return err }) { return } - appName := "Unknown..." - if getName, ok := info.(interface{ GetName() string }); ok { - appName = getName.GetName() - } + appName := info.GetName() h.ManageAppsGet(rw, &http.Request{ URL: &url.URL{ diff --git a/server/manage-users.go b/server/manage-users.go index d24a815..fc776ae 100644 --- a/server/manage-users.go +++ b/server/manage-users.go @@ -22,13 +22,13 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ } var roles string - var userList []database.User - if h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err = tx.GetUserRoles(auth.Subject) + var userList []database.GetUserListRow + if h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err = tx.GetUserRoles(req.Context(), auth.Subject) if err != nil { return } - userList, err = tx.GetUserList(offset) + userList, err = tx.GetUserList(req.Context(), int64(offset)) return }) { return @@ -72,8 +72,8 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, } var roles string - if h.DbTx(rw, func(tx *database.Tx) (err error) { - roles, err = tx.GetUserRoles(auth.Subject) + if h.DbTx(rw, func(tx *database.Queries) (err error) { + roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return }) { return @@ -90,9 +90,13 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, switch action { case "edit": - if h.DbTx(rw, func(tx *database.Tx) error { + if h.DbTx(rw, func(tx *database.Queries) error { sub := req.Form.Get("subject") - return tx.UpdateUser(sub, newRoles, active) + return tx.UpdateUser(req.Context(), database.UpdateUserParams{ + Active: active, + Roles: newRoles, + Subject: sub, + }) }) { return } diff --git a/server/server.go b/server/server.go index ab6f4ff..d52b3be 100644 --- a/server/server.go +++ b/server/server.go @@ -30,7 +30,7 @@ type HttpServer struct { r *httprouter.Router oauthSrv *server.Server oauthMgr *manage.Manager - db *database.DB + db *database.Queries conf Conf signingKey mjwt.Signer manager *issuer.Manager @@ -42,7 +42,7 @@ type flowStateData struct { redirect string } -func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server { +func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server { r := httprouter.New() contentCache := time.Now() @@ -187,16 +187,16 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser return } - var user *database.User - if hs.DbTx(rw, func(tx *database.Tx) (err error) { - user, err = tx.GetUser(userId) - return err + var user database.User + if hs.DbTx(rw, func(tx *database.Queries) (err error) { + user, err = tx.GetUser(req.Context(), userId) + return }) { return } var userInfo UserInfoFields - err = json.Unmarshal([]byte(user.UserInfo), &userInfo) + err = json.Unmarshal([]byte(user.Userinfo), &userInfo) if err != nil { http.Error(rw, "500 Internal Server Error", http.StatusInternalServerError) return diff --git a/sqlc.yaml b/sqlc.yaml new file mode 100644 index 0000000..7e08599 --- /dev/null +++ b/sqlc.yaml @@ -0,0 +1,10 @@ +version: "2" +sql: + - engine: sqlite + queries: database/queries + schema: database/migrations + gen: + go: + package: "database" + out: "database" + emit_json_tags: true