mirror of
https://github.com/1f349/lavender.git
synced 2024-12-21 15:14:07 +00:00
Start adding sqlc and migration support
This commit is contained in:
parent
08096a4b98
commit
3a7eef803e
@ -7,20 +7,16 @@ import (
|
||||
)
|
||||
|
||||
type ClientStore struct {
|
||||
db *database.DB
|
||||
db *database.Queries
|
||||
}
|
||||
|
||||
var _ oauth2.ClientStore = &ClientStore{}
|
||||
|
||||
func New(db *database.DB) *ClientStore {
|
||||
func New(db *database.Queries) *ClientStore {
|
||||
return &ClientStore{db: db}
|
||||
}
|
||||
|
||||
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
|
||||
tx, err := c.db.BeginCtx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
return tx.GetClientInfo(id)
|
||||
a, err := c.db.GetClientInfo(ctx, id)
|
||||
return &a, err
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"flag"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender"
|
||||
"github.com/1f349/lavender/logger"
|
||||
"github.com/1f349/lavender/pages"
|
||||
"github.com/1f349/lavender/server"
|
||||
@ -75,7 +75,7 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
||||
}
|
||||
saveMjwtPubKey(signingKey, wd)
|
||||
|
||||
db, err := database.Open(filepath.Join(wd, "lavender.db.sqlite"))
|
||||
db, err := lavender.InitDB(filepath.Join(wd, "lavender.db.sqlite"))
|
||||
if err != nil {
|
||||
logger.Logger.Fatal("Failed to open database:", err)
|
||||
}
|
||||
|
22
database/clientstore.go
Normal file
22
database/clientstore.go
Normal file
@ -0,0 +1,22 @@
|
||||
package database
|
||||
|
||||
import "github.com/go-oauth2/oauth2/v4"
|
||||
|
||||
var _ oauth2.ClientInfo = &ClientStore{}
|
||||
|
||||
func (c *ClientStore) GetID() string { return c.Subject }
|
||||
func (c *ClientStore) GetSecret() string { return c.Secret }
|
||||
func (c *ClientStore) GetDomain() string { return c.Domain }
|
||||
func (c *ClientStore) IsPublic() bool { return c.Public }
|
||||
func (c *ClientStore) GetUserID() string { return c.Owner }
|
||||
|
||||
// GetName is an extra field for the oauth handler to display the application
|
||||
// name
|
||||
func (c *ClientStore) GetName() string { return c.Name }
|
||||
|
||||
// IsSSO is an extra field for the oauth handler to skip the user input stage
|
||||
// this is for trusted applications to get permissions without asking the user
|
||||
func (c *ClientStore) IsSSO() bool { return c.Sso }
|
||||
|
||||
// IsActive is an extra field for the app manager to get the active state
|
||||
func (c *ClientStore) IsActive() bool { return c.Active }
|
@ -1,43 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Subject string `json:"subject"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Roles string `json:"roles"`
|
||||
UserInfo string `json:"userinfo"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type ClientInfoDbOutput struct {
|
||||
Subject, Name, Secret, Domain, Owner, Perms string
|
||||
Public, Sso, Active bool
|
||||
}
|
||||
|
||||
var _ oauth2.ClientInfo = &ClientInfoDbOutput{}
|
||||
|
||||
func (c *ClientInfoDbOutput) GetID() string { return c.Subject }
|
||||
func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret }
|
||||
func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain }
|
||||
func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public }
|
||||
func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner }
|
||||
|
||||
// GetName is an extra field for the oauth handler to display the application
|
||||
// name
|
||||
func (c *ClientInfoDbOutput) GetName() string { return c.Name }
|
||||
|
||||
// IsSSO is an extra field for the oauth handler to skip the user input stage
|
||||
// this is for trusted applications to get permissions without asking the user
|
||||
func (c *ClientInfoDbOutput) IsSSO() bool { return c.Sso }
|
||||
|
||||
// IsActive is an extra field for the app manager to get the active state
|
||||
func (c *ClientInfoDbOutput) IsActive() bool { return c.Active }
|
||||
|
||||
// UsePerms is an extra field for the userinfo handler to return user permissions matching the requested values
|
||||
func (c *ClientInfoDbOutput) UsePerms() string { return c.Perms }
|
@ -1,41 +1,31 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.25.0
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed init.sql
|
||||
var initSql string
|
||||
type DBTX interface {
|
||||
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
|
||||
PrepareContext(context.Context, string) (*sql.Stmt, error)
|
||||
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type DB struct{ db *sql.DB }
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
func Open(p string) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
_, err = db.Exec(initSql)
|
||||
return &DB{db: db}, err
|
||||
}
|
||||
|
||||
func (d *DB) Begin() (*Tx, error) {
|
||||
begin, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{begin}, err
|
||||
}
|
||||
|
||||
func (d *DB) BeginCtx(ctx context.Context) (*Tx, error) {
|
||||
begin, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{begin}, err
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
@ -1,5 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
@ -1,27 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS users
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
email_verified INTEGER DEFAULT 0 NOT NULL,
|
||||
roles TEXT NOT NULL,
|
||||
userinfo TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
expiry DATETIME,
|
||||
updated_at DATETIME,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS client_store
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
owner TEXT NOT NULL,
|
||||
perms TEXT NOT NULL,
|
||||
public INTEGER,
|
||||
sso INTEGER,
|
||||
active INTEGER DEFAULT 1,
|
||||
FOREIGN KEY (owner) REFERENCES users (subject)
|
||||
);
|
178
database/manage-oauth.sql.go
Normal file
178
database/manage-oauth.sql.go
Normal file
@ -0,0 +1,178 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.25.0
|
||||
// source: manage-oauth.sql
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const getAppList = `-- name: GetAppList :many
|
||||
SELECT subject, name, domain, owner, public, sso, active
|
||||
FROM client_store
|
||||
WHERE owner = ?
|
||||
OR ? = 1
|
||||
LIMIT 25 OFFSET ?
|
||||
`
|
||||
|
||||
type GetAppListParams struct {
|
||||
Owner string `json:"owner"`
|
||||
Column2 interface{} `json:"column_2"`
|
||||
Offset int64 `json:"offset"`
|
||||
}
|
||||
|
||||
type GetAppListRow struct {
|
||||
Subject string `json:"subject"`
|
||||
Name string `json:"name"`
|
||||
Domain string `json:"domain"`
|
||||
Owner string `json:"owner"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAppList, arg.Owner, arg.Column2, arg.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetAppListRow
|
||||
for rows.Next() {
|
||||
var i GetAppListRow
|
||||
if err := rows.Scan(
|
||||
&i.Subject,
|
||||
&i.Name,
|
||||
&i.Domain,
|
||||
&i.Owner,
|
||||
&i.Public,
|
||||
&i.Sso,
|
||||
&i.Active,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getClientInfo = `-- name: GetClientInfo :one
|
||||
SELECT subject, name, secret, domain, owner, perms, public, sso, active
|
||||
FROM client_store
|
||||
WHERE subject = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) {
|
||||
row := q.db.QueryRowContext(ctx, getClientInfo, subject)
|
||||
var i ClientStore
|
||||
err := row.Scan(
|
||||
&i.Subject,
|
||||
&i.Name,
|
||||
&i.Secret,
|
||||
&i.Domain,
|
||||
&i.Owner,
|
||||
&i.Perms,
|
||||
&i.Public,
|
||||
&i.Sso,
|
||||
&i.Active,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertClientApp = `-- name: InsertClientApp :exec
|
||||
INSERT INTO client_store (subject, name, secret, domain, owner, perms, public, sso, active)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
type InsertClientAppParams struct {
|
||||
Subject string `json:"subject"`
|
||||
Name string `json:"name"`
|
||||
Secret string `json:"secret"`
|
||||
Domain string `json:"domain"`
|
||||
Owner string `json:"owner"`
|
||||
Perms string `json:"perms"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error {
|
||||
_, err := q.db.ExecContext(ctx, insertClientApp,
|
||||
arg.Subject,
|
||||
arg.Name,
|
||||
arg.Secret,
|
||||
arg.Domain,
|
||||
arg.Owner,
|
||||
arg.Perms,
|
||||
arg.Public,
|
||||
arg.Sso,
|
||||
arg.Active,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const resetClientAppSecret = `-- name: ResetClientAppSecret :exec
|
||||
UPDATE client_store
|
||||
SET secret = ?
|
||||
WHERE subject = ?
|
||||
AND owner = ?
|
||||
`
|
||||
|
||||
type ResetClientAppSecretParams struct {
|
||||
Secret string `json:"secret"`
|
||||
Subject string `json:"subject"`
|
||||
Owner string `json:"owner"`
|
||||
}
|
||||
|
||||
func (q *Queries) ResetClientAppSecret(ctx context.Context, arg ResetClientAppSecretParams) error {
|
||||
_, err := q.db.ExecContext(ctx, resetClientAppSecret, arg.Secret, arg.Subject, arg.Owner)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateClientApp = `-- name: UpdateClientApp :exec
|
||||
UPDATE client_store
|
||||
SET name = ?,
|
||||
domain = ?,
|
||||
perms = CASE WHEN CAST(? AS BOOLEAN) = true THEN ? ELSE perms END,
|
||||
public = ?,
|
||||
sso = ?,
|
||||
active = ?
|
||||
WHERE subject = ?
|
||||
AND owner = ?
|
||||
`
|
||||
|
||||
type UpdateClientAppParams struct {
|
||||
Name string `json:"name"`
|
||||
Domain string `json:"domain"`
|
||||
Column3 bool `json:"column_3"`
|
||||
Perms string `json:"perms"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
Subject string `json:"subject"`
|
||||
Owner string `json:"owner"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateClientApp,
|
||||
arg.Name,
|
||||
arg.Domain,
|
||||
arg.Column3,
|
||||
arg.Perms,
|
||||
arg.Public,
|
||||
arg.Sso,
|
||||
arg.Active,
|
||||
arg.Subject,
|
||||
arg.Owner,
|
||||
)
|
||||
return err
|
||||
}
|
90
database/manage-users.sql.go
Normal file
90
database/manage-users.sql.go
Normal file
@ -0,0 +1,90 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.25.0
|
||||
// source: manage-users.sql
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const getUserList = `-- name: GetUserList :many
|
||||
SELECT subject,
|
||||
email,
|
||||
email_verified,
|
||||
roles,
|
||||
updated_at,
|
||||
active
|
||||
FROM users
|
||||
LIMIT 25 OFFSET ?
|
||||
`
|
||||
|
||||
type GetUserListRow struct {
|
||||
Subject string `json:"subject"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Roles string `json:"roles"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUserList, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetUserListRow
|
||||
for rows.Next() {
|
||||
var i GetUserListRow
|
||||
if err := rows.Scan(
|
||||
&i.Subject,
|
||||
&i.Email,
|
||||
&i.EmailVerified,
|
||||
&i.Roles,
|
||||
&i.UpdatedAt,
|
||||
&i.Active,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUser = `-- name: UpdateUser :exec
|
||||
UPDATE users
|
||||
SET active = ?,
|
||||
roles=?
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
type UpdateUserParams struct {
|
||||
Active bool `json:"active"`
|
||||
Roles string `json:"roles"`
|
||||
Subject string `json:"subject"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateUser, arg.Active, arg.Roles, arg.Subject)
|
||||
return err
|
||||
}
|
||||
|
||||
const userEmailExists = `-- name: UserEmailExists :one
|
||||
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists
|
||||
`
|
||||
|
||||
func (q *Queries) UserEmailExists(ctx context.Context, email string) (bool, error) {
|
||||
row := q.db.QueryRowContext(ctx, userEmailExists, email)
|
||||
var email_exists bool
|
||||
err := row.Scan(&email_exists)
|
||||
return email_exists, err
|
||||
}
|
2
database/migrations/20240517171813_init.down.sql
Normal file
2
database/migrations/20240517171813_init.down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
DROP TABLE users;
|
||||
DROP TABLE client_store;
|
27
database/migrations/20240517171813_init.up.sql
Normal file
27
database/migrations/20240517171813_init.up.sql
Normal file
@ -0,0 +1,27 @@
|
||||
CREATE TABLE users
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
email_verified BOOLEAN DEFAULT 0 NOT NULL,
|
||||
roles TEXT NOT NULL,
|
||||
userinfo TEXT NOT NULL,
|
||||
access_token TEXT NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
expiry DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL,
|
||||
active BOOLEAN DEFAULT 1 NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE client_store
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
owner TEXT NOT NULL,
|
||||
perms TEXT NOT NULL,
|
||||
public BOOLEAN NOT NULL,
|
||||
sso BOOLEAN NOT NULL,
|
||||
active BOOLEAN DEFAULT 1 NOT NULL,
|
||||
FOREIGN KEY (owner) REFERENCES users (subject)
|
||||
);
|
34
database/models.go
Normal file
34
database/models.go
Normal file
@ -0,0 +1,34 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.25.0
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ClientStore struct {
|
||||
Subject string `json:"subject"`
|
||||
Name string `json:"name"`
|
||||
Secret string `json:"secret"`
|
||||
Domain string `json:"domain"`
|
||||
Owner string `json:"owner"`
|
||||
Perms string `json:"perms"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Subject string `json:"subject"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Roles string `json:"roles"`
|
||||
Userinfo string `json:"userinfo"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
33
database/queries/manage-oauth.sql
Normal file
33
database/queries/manage-oauth.sql
Normal file
@ -0,0 +1,33 @@
|
||||
-- name: GetClientInfo :one
|
||||
SELECT *
|
||||
FROM client_store
|
||||
WHERE subject = ?
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetAppList :many
|
||||
SELECT subject, name, domain, owner, public, sso, active
|
||||
FROM client_store
|
||||
WHERE owner = ?
|
||||
OR ? = 1
|
||||
LIMIT 25 OFFSET ?;
|
||||
|
||||
-- name: InsertClientApp :exec
|
||||
INSERT INTO client_store (subject, name, secret, domain, owner, perms, public, sso, active)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
|
||||
-- name: UpdateClientApp :exec
|
||||
UPDATE client_store
|
||||
SET name = ?,
|
||||
domain = ?,
|
||||
perms = CASE WHEN CAST(? AS BOOLEAN) = true THEN ? ELSE perms END,
|
||||
public = ?,
|
||||
sso = ?,
|
||||
active = ?
|
||||
WHERE subject = ?
|
||||
AND owner = ?;
|
||||
|
||||
-- name: ResetClientAppSecret :exec
|
||||
UPDATE client_store
|
||||
SET secret = ?
|
||||
WHERE subject = ?
|
||||
AND owner = ?;
|
18
database/queries/manage-users.sql
Normal file
18
database/queries/manage-users.sql
Normal file
@ -0,0 +1,18 @@
|
||||
-- name: GetUserList :many
|
||||
SELECT subject,
|
||||
email,
|
||||
email_verified,
|
||||
roles,
|
||||
updated_at,
|
||||
active
|
||||
FROM users
|
||||
LIMIT 25 OFFSET ?;
|
||||
|
||||
-- name: UpdateUser :exec
|
||||
UPDATE users
|
||||
SET active = ?,
|
||||
roles=?
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: UserEmailExists :one
|
||||
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists;
|
43
database/queries/users.sql
Normal file
43
database/queries/users.sql
Normal file
@ -0,0 +1,43 @@
|
||||
-- name: HasUser :one
|
||||
SELECT count(subject) > 0 AS hasUser
|
||||
FROM users;
|
||||
|
||||
-- name: AddUser :exec
|
||||
INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
||||
|
||||
-- name: UpdateUserInfo :exec
|
||||
UPDATE users
|
||||
SET email = ?,
|
||||
email_verified = ?,
|
||||
userinfo = ?
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: GetUserRoles :one
|
||||
SELECT roles
|
||||
FROM users
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: GetUser :one
|
||||
SELECT *
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
LIMIT 1;
|
||||
|
||||
-- name: UpdateUserToken :exec
|
||||
UPDATE users
|
||||
SET access_token = ?,
|
||||
refresh_token = ?,
|
||||
expiry = ?
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: GetUserToken :one
|
||||
SELECT access_token, refresh_token, expiry
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetUserEmail :one
|
||||
SELECT email
|
||||
FROM users
|
||||
WHERE subject = ?;
|
160
database/tx.go
160
database/tx.go
@ -1,160 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/1f349/lavender/password"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
func updatedAt() string {
|
||||
return time.Now().UTC().Format(time.DateTime)
|
||||
}
|
||||
|
||||
type Tx struct{ tx *sql.Tx }
|
||||
|
||||
func (t *Tx) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
func (t *Tx) Rollback() {
|
||||
_ = t.tx.Rollback()
|
||||
}
|
||||
|
||||
func (t *Tx) HasUser() error {
|
||||
var exists bool
|
||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
|
||||
err := row.Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) InsertUser(subject, email string, verifyEmail bool, roles, userinfo string, active bool) error {
|
||||
_, err := t.tx.Exec(`INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?)`, subject, email, verifyEmail, roles, userinfo, updatedAt(), active)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) UpdateUserInfo(subject, email string, verified bool, userinfo string) error {
|
||||
_, err := t.tx.Exec(`UPDATE users SET email = ?, email_verified = ?, userinfo = ? WHERE subject = ?`, email, verified, userinfo, subject)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserRoles(sub string) (string, error) {
|
||||
var r string
|
||||
row := t.tx.QueryRow(`SELECT roles FROM users WHERE subject = ? LIMIT 1`, sub)
|
||||
err := row.Scan(&r)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUser(sub string) (*User, error) {
|
||||
var u User
|
||||
row := t.tx.QueryRow(`SELECT email, email_verified, roles, userinfo, updated_at, active FROM users WHERE subject = ?`, sub)
|
||||
err := row.Scan(&u.Email, &u.EmailVerified, &u.Roles, &u.UserInfo, &u.UpdatedAt, &u.Active)
|
||||
u.Subject = sub
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserEmail(sub string) (string, error) {
|
||||
var email string
|
||||
row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub)
|
||||
err := row.Scan(&email)
|
||||
return email, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
|
||||
var u ClientInfoDbOutput
|
||||
row := t.tx.QueryRow(`SELECT secret, name, domain, perms, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
|
||||
err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Perms, &u.Public, &u.Sso, &u.Active)
|
||||
u.Owner = sub
|
||||
if !u.Active {
|
||||
return nil, fmt.Errorf("client is not active")
|
||||
}
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) {
|
||||
var u []ClientInfoDbOutput
|
||||
row, err := t.tx.Query(`SELECT subject, name, domain, owner, perms, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer row.Close()
|
||||
for row.Next() {
|
||||
var a ClientInfoDbOutput
|
||||
err := row.Scan(&a.Subject, &a.Name, &a.Domain, &a.Owner, &a.Perms, &a.Public, &a.Sso, &a.Active)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u = append(u, a)
|
||||
}
|
||||
return u, row.Err()
|
||||
}
|
||||
|
||||
func (t *Tx) InsertClientApp(name, domain, owner, perms string, public, sso, active bool) error {
|
||||
u := uuid.New()
|
||||
secret, err := password.GenerateApiSecret(70)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, perms, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, perms, public, sso, active)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) UpdateClientApp(subject uuid.UUID, owner, name, domain, perms string, hasPerms, public, sso, active bool) error {
|
||||
_, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, perms = CASE WHEN ? = true THEN ? ELSE perms END, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, hasPerms, perms, public, sso, active, subject.String(), owner)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) ResetClientAppSecret(subject uuid.UUID, owner string) (string, error) {
|
||||
secret, err := password.GenerateApiSecret(70)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
_, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject.String(), owner)
|
||||
return secret, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserList(offset int) ([]User, error) {
|
||||
var u []User
|
||||
row, err := t.tx.Query(`SELECT subject, email, email_verified, roles, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for row.Next() {
|
||||
var a User
|
||||
err := row.Scan(&a.Subject, &a.Email, &a.EmailVerified, &a.Roles, &a.UpdatedAt, &a.Active)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u = append(u, a)
|
||||
}
|
||||
return u, row.Err()
|
||||
}
|
||||
|
||||
func (t *Tx) UpdateUser(subject, roles string, active bool) error {
|
||||
_, err := t.tx.Exec(`UPDATE users SET active = ?, roles = ? WHERE subject = ?`, active, roles, subject)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) UpdateUserToken(subject, accessToken, refreshToken string, expiry time.Time) error {
|
||||
_, err := t.tx.Exec(`UPDATE users SET access_token = ?, refresh_token = ?, expiry = ? WHERE subject = ?`, accessToken, refreshToken, expiry, subject)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserToken(subject string, accessToken, refreshToken *string, expiry *time.Time) error {
|
||||
row := t.tx.QueryRow(`SELECT access_token, refresh_token, expiry FROM users WHERE subject = ? LIMIT 1`, subject)
|
||||
return row.Scan(accessToken, refreshToken, expiry)
|
||||
}
|
||||
|
||||
func (t *Tx) UserEmailExists(email string) (exists bool, err error) {
|
||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email)
|
||||
err = row.Scan(&exists)
|
||||
return
|
||||
}
|
172
database/users.sql.go
Normal file
172
database/users.sql.go
Normal file
@ -0,0 +1,172 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.25.0
|
||||
// source: users.sql
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const addUser = `-- name: AddUser :exec
|
||||
INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
type AddUserParams struct {
|
||||
Subject string `json:"subject"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Roles string `json:"roles"`
|
||||
Userinfo string `json:"userinfo"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) error {
|
||||
_, err := q.db.ExecContext(ctx, addUser,
|
||||
arg.Subject,
|
||||
arg.Email,
|
||||
arg.EmailVerified,
|
||||
arg.Roles,
|
||||
arg.Userinfo,
|
||||
arg.UpdatedAt,
|
||||
arg.Active,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUser = `-- name: GetUser :one
|
||||
SELECT subject, email, email_verified, roles, userinfo, access_token, refresh_token, expiry, updated_at, active
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUser, subject)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.Subject,
|
||||
&i.Email,
|
||||
&i.EmailVerified,
|
||||
&i.Roles,
|
||||
&i.Userinfo,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.Expiry,
|
||||
&i.UpdatedAt,
|
||||
&i.Active,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserEmail = `-- name: GetUserEmail :one
|
||||
SELECT email
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserEmail(ctx context.Context, subject string) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserEmail, subject)
|
||||
var email string
|
||||
err := row.Scan(&email)
|
||||
return email, err
|
||||
}
|
||||
|
||||
const getUserRoles = `-- name: GetUserRoles :one
|
||||
SELECT roles
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserRoles(ctx context.Context, subject string) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserRoles, subject)
|
||||
var roles string
|
||||
err := row.Scan(&roles)
|
||||
return roles, err
|
||||
}
|
||||
|
||||
const getUserToken = `-- name: GetUserToken :one
|
||||
SELECT access_token, refresh_token, expiry
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
type GetUserTokenRow struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserToken(ctx context.Context, subject string) (GetUserTokenRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserToken, subject)
|
||||
var i GetUserTokenRow
|
||||
err := row.Scan(&i.AccessToken, &i.RefreshToken, &i.Expiry)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const hasUser = `-- name: HasUser :one
|
||||
SELECT count(subject) > 0 AS hasUser
|
||||
FROM users
|
||||
`
|
||||
|
||||
func (q *Queries) HasUser(ctx context.Context) (bool, error) {
|
||||
row := q.db.QueryRowContext(ctx, hasUser)
|
||||
var hasuser bool
|
||||
err := row.Scan(&hasuser)
|
||||
return hasuser, err
|
||||
}
|
||||
|
||||
const updateUserInfo = `-- name: UpdateUserInfo :exec
|
||||
UPDATE users
|
||||
SET email = ?,
|
||||
email_verified = ?,
|
||||
userinfo = ?
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
type UpdateUserInfoParams struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Userinfo string `json:"userinfo"`
|
||||
Subject string `json:"subject"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserInfo(ctx context.Context, arg UpdateUserInfoParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateUserInfo,
|
||||
arg.Email,
|
||||
arg.EmailVerified,
|
||||
arg.Userinfo,
|
||||
arg.Subject,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateUserToken = `-- name: UpdateUserToken :exec
|
||||
UPDATE users
|
||||
SET access_token = ?,
|
||||
refresh_token = ?,
|
||||
expiry = ?
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
type UpdateUserTokenParams struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
Subject string `json:"subject"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserToken(ctx context.Context, arg UpdateUserTokenParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateUserToken,
|
||||
arg.AccessToken,
|
||||
arg.RefreshToken,
|
||||
arg.Expiry,
|
||||
arg.Subject,
|
||||
)
|
||||
return err
|
||||
}
|
5
go.mod
5
go.mod
@ -10,6 +10,7 @@ require (
|
||||
github.com/charmbracelet/log v0.4.0
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.2
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/golang-migrate/migrate/v4 v4.17.1
|
||||
github.com/google/subcommands v1.2.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
@ -26,7 +27,8 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/klauspost/compress v1.17.8 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@ -45,6 +47,7 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/rtred v0.1.2 // indirect
|
||||
github.com/tidwall/tinyqueue v0.1.1 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
|
12
go.sum
12
go.sum
@ -39,6 +39,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4=
|
||||
github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
@ -50,7 +52,6 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
@ -66,6 +67,11 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
|
||||
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
|
||||
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
|
||||
@ -83,6 +89,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
@ -179,6 +187,8 @@ github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FB
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM=
|
||||
github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
|
38
initdb.go
Normal file
38
initdb.go
Normal file
@ -0,0 +1,38 @@
|
||||
package lavender
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
)
|
||||
|
||||
//go:embed database/migrations/*.sql
|
||||
var migrations embed.FS
|
||||
|
||||
func InitDB(p string) (*database.Queries, error) {
|
||||
migDrv, err := iofs.New(migrations, "database/migrations")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbOpen, err := sql.Open("sqlite3", p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbDrv, err := sqlite3.WithInstance(dbOpen, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mig, err := migrate.NewWithInstance("iofs", migDrv, "sqlite3", dbDrv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = mig.Up()
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
return nil, err
|
||||
}
|
||||
return database.New(dbOpen), nil
|
||||
}
|
@ -21,8 +21,8 @@ func (u UserAuth) IsGuest() bool { return u.Subject == "" }
|
||||
func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle {
|
||||
return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
var roles string
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err = tx.GetUserRoles(auth.Subject)
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
|
29
server/db.go
29
server/db.go
@ -9,35 +9,14 @@ import (
|
||||
// DbTx wraps a database transaction with http error messages and a simple action
|
||||
// function. If the action function returns an error the transaction will be
|
||||
// rolled back. If there is no error then the transaction is committed.
|
||||
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool {
|
||||
tx, err := h.db.Begin()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
err = action(tx)
|
||||
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool {
|
||||
err := action(h.db)
|
||||
if err != nil {
|
||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||
logger.Logger.Warn("Database action error", "er", err)
|
||||
logger.Logger.Helper()
|
||||
logger.Logger.Warn("Database action error", "err", err)
|
||||
return true
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||
logger.Logger.Warn("Database commit error", "err", err)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *HttpServer) DbTxRaw(action func(tx *database.Tx) error) bool {
|
||||
return h.DbTx(&fakeRW{}, action)
|
||||
}
|
||||
|
||||
type fakeRW struct{}
|
||||
|
||||
func (f *fakeRW) Header() http.Header { return http.Header{} }
|
||||
func (f *fakeRW) Write(b []byte) (int, error) { return len(b), nil }
|
||||
func (f *fakeRW) WriteHeader(statusCode int) {}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (h *HttpServer) Home(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||
func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
lNonce := uuid.NewString()
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
@ -29,8 +29,8 @@ func (h *HttpServer) Home(rw http.ResponseWriter, _ *http.Request, _ httprouter.
|
||||
}
|
||||
|
||||
var isAdmin bool
|
||||
h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err := tx.GetUserRoles(auth.Subject)
|
||||
h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err := tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
isAdmin = HasRole(roles, "lavender:admin")
|
||||
return err
|
||||
})
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
@ -9,7 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) {
|
||||
func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) {
|
||||
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
|
||||
scope := ti.GetScope()
|
||||
if containsScope(scope, "openid") {
|
||||
@ -32,13 +33,8 @@ type IdTokenClaims struct {
|
||||
func (a IdTokenClaims) Valid() error { return nil }
|
||||
func (a IdTokenClaims) Type() string { return "id-token" }
|
||||
|
||||
func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) {
|
||||
tx, err := us.Begin()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
user, err := tx.GetUser(ti.GetUserID())
|
||||
func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) {
|
||||
user, err := us.GetUser(context.Background(), ti.GetUserID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -16,25 +16,20 @@ import (
|
||||
|
||||
type JWTAccessGenerate struct {
|
||||
signer mjwt.Signer
|
||||
db *database.DB
|
||||
db *database.Queries
|
||||
}
|
||||
|
||||
func NewJWTAccessGenerate(signer mjwt.Signer, db *database.DB) *JWTAccessGenerate {
|
||||
func NewJWTAccessGenerate(signer mjwt.Signer, db *database.Queries) *JWTAccessGenerate {
|
||||
return &JWTAccessGenerate{signer, db}
|
||||
}
|
||||
|
||||
var _ oauth2.AccessGenerate = &JWTAccessGenerate{}
|
||||
|
||||
func (j *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (access, refresh string, err error) {
|
||||
beginCtx, err := j.db.BeginCtx(ctx)
|
||||
roles, err := j.db.GetUserRoles(ctx, data.UserID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
roles, err := beginCtx.GetUserRoles(data.UserID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
beginCtx.Rollback()
|
||||
|
||||
ps := claims.ParsePermStorage(roles)
|
||||
out := claims.NewPermStorage()
|
||||
|
@ -114,20 +114,33 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _
|
||||
return
|
||||
}
|
||||
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
jBytes, err := json.Marshal(sessionData.UserInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.GetUser(sessionData.Subject)
|
||||
_, err = tx.GetUser(req.Context(), sessionData.Subject)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
||||
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
||||
return tx.InsertUser(sessionData.Subject, uEmail, uEmailVerified, "", string(jBytes), true)
|
||||
return tx.AddUser(req.Context(), database.AddUserParams{
|
||||
Subject: sessionData.Subject,
|
||||
Email: uEmail,
|
||||
EmailVerified: uEmailVerified,
|
||||
Roles: "",
|
||||
Userinfo: string(jBytes),
|
||||
UpdatedAt: time.Now(),
|
||||
Active: true,
|
||||
})
|
||||
}
|
||||
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
||||
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
||||
return tx.UpdateUserInfo(sessionData.Subject, uEmail, uEmailVerified, string(jBytes))
|
||||
return tx.UpdateUserInfo(req.Context(), database.UpdateUserInfoParams{
|
||||
Email: sessionData.Subject,
|
||||
EmailVerified: uEmailVerified,
|
||||
Userinfo: string(jBytes),
|
||||
Subject: uEmail,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
@ -135,8 +148,13 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _
|
||||
// only continues if the above tx succeeds
|
||||
auth = sessionData
|
||||
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
return tx.UpdateUserToken(auth.Subject, token.AccessToken, token.RefreshToken, token.Expiry)
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expiry: token.Expiry,
|
||||
Subject: auth.Subject,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ package server
|
||||
import (
|
||||
"github.com/1f349/lavender/database"
|
||||
"github.com/1f349/lavender/pages"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"github.com/1f349/lavender/password"
|
||||
"github.com/google/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
@ -24,13 +24,17 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
||||
}
|
||||
|
||||
var roles string
|
||||
var appList []database.ClientInfoDbOutput
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err = tx.GetUserRoles(auth.Subject)
|
||||
var appList []database.GetAppListRow
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
appList, err = tx.GetAppList(auth.Subject, HasRole(roles, "lavender:admin"), offset)
|
||||
appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{
|
||||
Owner: auth.Subject,
|
||||
Column2: HasRole(roles, "lavender:admin"),
|
||||
Offset: int64(offset),
|
||||
})
|
||||
return
|
||||
}) {
|
||||
return
|
||||
@ -63,10 +67,10 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
||||
pages.RenderPageTemplate(rw, "manage-apps", m)
|
||||
}
|
||||
|
||||
func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||
func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||
var roles string
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err = tx.GetUserRoles(auth.Subject)
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
@ -100,8 +104,8 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
||||
|
||||
if sso || hasPerms {
|
||||
var roles string
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err = tx.GetUserRoles(auth.Subject)
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
@ -118,43 +122,64 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
||||
|
||||
switch action {
|
||||
case "create":
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
return tx.InsertClientApp(name, domain, auth.Subject, perms, public, sso, active)
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
secret, err := password.GenerateApiSecret(70)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.InsertClientApp(req.Context(), database.InsertClientAppParams{
|
||||
Subject: uuid.NewString(),
|
||||
Name: name,
|
||||
Secret: secret,
|
||||
Domain: domain,
|
||||
Owner: auth.Subject,
|
||||
Perms: perms,
|
||||
Public: public,
|
||||
Sso: sso,
|
||||
Active: active,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
case "edit":
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
sub, err := uuid.Parse(req.Form.Get("subject"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.UpdateClientApp(sub, auth.Subject, name, domain, perms, hasPerms, public, sso, active)
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.UpdateClientApp(req.Context(), database.UpdateClientAppParams{
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
Column3: hasPerms,
|
||||
Public: public,
|
||||
Sso: sso,
|
||||
Active: active,
|
||||
Subject: req.FormValue("subject"),
|
||||
Owner: auth.Subject,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
case "secret":
|
||||
var info oauth2.ClientInfo
|
||||
var info database.ClientStore
|
||||
var secret string
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
sub, err := uuid.Parse(req.Form.Get("subject"))
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
sub := req.Form.Get("subject")
|
||||
info, err = tx.GetClientInfo(req.Context(), sub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err = tx.GetClientInfo(sub.String())
|
||||
secret, err := password.GenerateApiSecret(70)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secret, err = tx.ResetClientAppSecret(sub, auth.Subject)
|
||||
err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{
|
||||
Secret: secret,
|
||||
Subject: sub,
|
||||
Owner: auth.Subject,
|
||||
})
|
||||
return err
|
||||
}) {
|
||||
return
|
||||
}
|
||||
|
||||
appName := "Unknown..."
|
||||
if getName, ok := info.(interface{ GetName() string }); ok {
|
||||
appName = getName.GetName()
|
||||
}
|
||||
appName := info.GetName()
|
||||
|
||||
h.ManageAppsGet(rw, &http.Request{
|
||||
URL: &url.URL{
|
||||
|
@ -22,13 +22,13 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
|
||||
}
|
||||
|
||||
var roles string
|
||||
var userList []database.User
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err = tx.GetUserRoles(auth.Subject)
|
||||
var userList []database.GetUserListRow
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
userList, err = tx.GetUserList(offset)
|
||||
userList, err = tx.GetUserList(req.Context(), int64(offset))
|
||||
return
|
||||
}) {
|
||||
return
|
||||
@ -72,8 +72,8 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
||||
}
|
||||
|
||||
var roles string
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
roles, err = tx.GetUserRoles(auth.Subject)
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
@ -90,9 +90,13 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
switch action {
|
||||
case "edit":
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
sub := req.Form.Get("subject")
|
||||
return tx.UpdateUser(sub, newRoles, active)
|
||||
return tx.UpdateUser(req.Context(), database.UpdateUserParams{
|
||||
Active: active,
|
||||
Roles: newRoles,
|
||||
Subject: sub,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ type HttpServer struct {
|
||||
r *httprouter.Router
|
||||
oauthSrv *server.Server
|
||||
oauthMgr *manage.Manager
|
||||
db *database.DB
|
||||
db *database.Queries
|
||||
conf Conf
|
||||
signingKey mjwt.Signer
|
||||
manager *issuer.Manager
|
||||
@ -42,7 +42,7 @@ type flowStateData struct {
|
||||
redirect string
|
||||
}
|
||||
|
||||
func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server {
|
||||
func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server {
|
||||
r := httprouter.New()
|
||||
contentCache := time.Now()
|
||||
|
||||
@ -187,16 +187,16 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
|
||||
return
|
||||
}
|
||||
|
||||
var user *database.User
|
||||
if hs.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
user, err = tx.GetUser(userId)
|
||||
return err
|
||||
var user database.User
|
||||
if hs.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
user, err = tx.GetUser(req.Context(), userId)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
}
|
||||
|
||||
var userInfo UserInfoFields
|
||||
err = json.Unmarshal([]byte(user.UserInfo), &userInfo)
|
||||
err = json.Unmarshal([]byte(user.Userinfo), &userInfo)
|
||||
if err != nil {
|
||||
http.Error(rw, "500 Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user