Fix a bunch more compile breaking issues

This commit is contained in:
Melon 2024-10-05 21:08:02 +01:00
parent 7064afd55e
commit d25f9ae2ca
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
43 changed files with 1574 additions and 452 deletions

View File

@ -1,5 +1,7 @@
package auth package auth
import "github.com/hardfinhq/go-date"
type UserInfoFields map[string]any type UserInfoFields map[string]any
func (u UserInfoFields) GetString(key string) (string, bool) { func (u UserInfoFields) GetString(key string) (string, bool) {
@ -20,7 +22,24 @@ func (u UserInfoFields) GetStringOrEmpty(key string) string {
return s return s
} }
func (u UserInfoFields) GetStringFromKeysOrEmpty(keys ...string) string {
for _, key := range keys {
s, _ := u[key].(string)
if s == "" {
continue
}
return s
}
return ""
}
func (u UserInfoFields) GetBoolean(key string) (bool, bool) { func (u UserInfoFields) GetBoolean(key string) (bool, bool) {
b, ok := u[key].(bool) b, ok := u[key].(bool)
return b, ok return b, ok
} }
func (u UserInfoFields) GetNullDate(key string) date.NullDate {
s, _ := u[key].(string)
fromStr, err := date.FromString(s)
return date.NullDate{Date: fromStr, Valid: err == nil}
}

View File

@ -3,10 +3,13 @@ package main
import ( import (
"context" "context"
"flag" "flag"
"fmt"
"github.com/1f349/lavender" "github.com/1f349/lavender"
"github.com/1f349/lavender/conf" "github.com/1f349/lavender/conf"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/logger" "github.com/1f349/lavender/logger"
"github.com/1f349/lavender/pages" "github.com/1f349/lavender/pages"
"github.com/1f349/lavender/role"
"github.com/1f349/lavender/server" "github.com/1f349/lavender/server"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
@ -114,6 +117,10 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
logger.Logger.Fatal("Failed to open database", "err", err) logger.Logger.Fatal("Failed to open database", "err", err)
} }
if err := checkDbHasUser(db); err != nil {
logger.Logger.Fatal("Failed to add initial user", "err", err)
}
if err := pages.LoadPages(wd); err != nil { if err := pages.LoadPages(wd); err != nil {
logger.Logger.Fatal("Failed to load page templates:", err) logger.Logger.Fatal("Failed to load page templates:", err)
} }
@ -168,3 +175,45 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
return subcommands.ExitSuccess return subcommands.ExitSuccess
} }
func checkDbHasUser(db *database.Queries) error {
value, err := db.HasUser(context.Background())
if err != nil {
return err
}
if !value {
logger.Logger.Warn("No users are available, setting up initial admin user")
ctx := context.Background()
err = db.UseTx(ctx, func(tx *database.Queries) error {
adminUuid, err := db.AddLocalUser(context.Background(), database.AddLocalUserParams{
Password: "admin",
Email: "admin@localhost",
EmailVerified: false,
Name: "Admin",
Username: "admin",
ChangePassword: true,
})
if err != nil {
return fmt.Errorf("failed to add user: %w", err)
}
roleId, err := db.AddRole(context.Background(), role.LavenderAdmin)
if err != nil {
return fmt.Errorf("failed to add role: %w", err)
}
err = db.AddUserRole(context.Background(), database.AddUserRoleParams{
RoleID: roleId,
Subject: adminUuid,
})
if err != nil {
return fmt.Errorf("failed to add user role: %w", err)
}
return nil
})
if err != nil {
return err
}
}
return nil
}

View File

@ -12,6 +12,7 @@ type Conf struct {
Issuer string `yaml:"issuer"` Issuer string `yaml:"issuer"`
Kid string `yaml:"kid"` Kid string `yaml:"kid"`
Namespace string `yaml:"namespace"` Namespace string `yaml:"namespace"`
OtpIssuer string `yaml:"otpIssuer"`
Mail mail.Mail `yaml:"mail"` Mail mail.Mail `yaml:"mail"`
SsoServices []issuer.SsoConfig `yaml:"ssoServices"` SsoServices map[string]issuer.SsoConfig `yaml:"ssoServices"`
} }

View File

@ -20,13 +20,13 @@ SELECT subject,
active active
FROM client_store FROM client_store
WHERE owner_subject = ? WHERE owner_subject = ?
OR ? = 1 OR CAST(? AS BOOLEAN) = 1
LIMIT 25 OFFSET ? LIMIT 25 OFFSET ?
` `
type GetAppListParams struct { type GetAppListParams struct {
OwnerSubject string `json:"owner_subject"` OwnerSubject string `json:"owner_subject"`
Column2 interface{} `json:"column_2"` Column2 bool `json:"column_2"`
Offset int64 `json:"offset"` Offset int64 `json:"offset"`
} }

View File

@ -7,10 +7,30 @@ package database
import ( import (
"context" "context"
"database/sql"
"strings" "strings"
"time" "time"
"github.com/1f349/lavender/database/types"
) )
const addUserRole = `-- name: AddUserRole :exec
INSERT INTO users_roles(role_id, user_id)
SELECT ?, users.id
FROM users
WHERE subject = ?
`
type AddUserRoleParams struct {
RoleID int64 `json:"role_id"`
Subject string `json:"subject"`
}
func (q *Queries) AddUserRole(ctx context.Context, arg AddUserRoleParams) error {
_, err := q.db.ExecContext(ctx, addUserRole, arg.RoleID, arg.Subject)
return err
}
const changeUserActive = `-- name: ChangeUserActive :exec const changeUserActive = `-- name: ChangeUserActive :exec
UPDATE users UPDATE users
SET active = cast(? as boolean) SET active = cast(? as boolean)
@ -34,11 +54,9 @@ SELECT users.subject,
website, website,
email, email,
email_verified, email_verified,
users.updated_at as user_updated_at, updated_at,
p.updated_at as profile_updated_at,
active active
FROM users FROM users
INNER JOIN main.profiles p on users.subject = p.subject
LIMIT 50 OFFSET ? LIMIT 50 OFFSET ?
` `
@ -49,11 +67,11 @@ type GetUserListRow struct {
Website string `json:"website"` Website string `json:"website"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
UserUpdatedAt time.Time `json:"user_updated_at"` UpdatedAt time.Time `json:"updated_at"`
ProfileUpdatedAt time.Time `json:"profile_updated_at"`
Active bool `json:"active"` Active bool `json:"active"`
} }
// INNER JOIN main.profiles p on users.subject = p.subject
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) { func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
rows, err := q.db.QueryContext(ctx, getUserList, offset) rows, err := q.db.QueryContext(ctx, getUserList, offset)
if err != nil { if err != nil {
@ -70,8 +88,7 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR
&i.Website, &i.Website,
&i.Email, &i.Email,
&i.EmailVerified, &i.EmailVerified,
&i.UserUpdatedAt, &i.UpdatedAt,
&i.ProfileUpdatedAt,
&i.Active, &i.Active,
); err != nil { ); err != nil {
return nil, err return nil, err
@ -87,6 +104,25 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR
return items, nil return items, nil
} }
const getUserToken = `-- name: GetUserToken :one
SELECT access_token, refresh_token, token_expiry
FROM users
WHERE subject = ?
`
type GetUserTokenRow struct {
AccessToken sql.NullString `json:"access_token"`
RefreshToken sql.NullString `json:"refresh_token"`
TokenExpiry sql.NullTime `json:"token_expiry"`
}
func (q *Queries) GetUserToken(ctx context.Context, subject string) (GetUserTokenRow, error) {
row := q.db.QueryRowContext(ctx, getUserToken, subject)
var i GetUserTokenRow
err := row.Scan(&i.AccessToken, &i.RefreshToken, &i.TokenExpiry)
return i, err
}
const getUsersRoles = `-- name: GetUsersRoles :many const getUsersRoles = `-- name: GetUsersRoles :many
SELECT r.role, u.id SELECT r.role, u.id
FROM users_roles FROM users_roles
@ -133,6 +169,105 @@ func (q *Queries) GetUsersRoles(ctx context.Context, userIds []int64) ([]GetUser
return items, nil return items, nil
} }
const modifyUserAuth = `-- name: ModifyUserAuth :exec
UPDATE users
SET auth_type = ?,
auth_namespace=?,
auth_user = ?
WHERE subject = ?
`
type ModifyUserAuthParams struct {
AuthType types.AuthType `json:"auth_type"`
AuthNamespace string `json:"auth_namespace"`
AuthUser string `json:"auth_user"`
Subject string `json:"subject"`
}
func (q *Queries) ModifyUserAuth(ctx context.Context, arg ModifyUserAuthParams) error {
_, err := q.db.ExecContext(ctx, modifyUserAuth,
arg.AuthType,
arg.AuthNamespace,
arg.AuthUser,
arg.Subject,
)
return err
}
const modifyUserEmail = `-- name: ModifyUserEmail :exec
UPDATE users
SET email = ?,
email_verified=?
WHERE subject = ?
`
type ModifyUserEmailParams struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Subject string `json:"subject"`
}
func (q *Queries) ModifyUserEmail(ctx context.Context, arg ModifyUserEmailParams) error {
_, err := q.db.ExecContext(ctx, modifyUserEmail, arg.Email, arg.EmailVerified, arg.Subject)
return err
}
const modifyUserRemoteLogin = `-- name: ModifyUserRemoteLogin :exec
UPDATE users
SET login = ?,
profile_url = ?
WHERE subject = ?
`
type ModifyUserRemoteLoginParams struct {
Login string `json:"login"`
ProfileUrl string `json:"profile_url"`
Subject string `json:"subject"`
}
func (q *Queries) ModifyUserRemoteLogin(ctx context.Context, arg ModifyUserRemoteLoginParams) error {
_, err := q.db.ExecContext(ctx, modifyUserRemoteLogin, arg.Login, arg.ProfileUrl, arg.Subject)
return err
}
const removeUserRoles = `-- name: RemoveUserRoles :exec
DELETE
FROM users_roles
WHERE user_id IN (SELECT id
FROM users
WHERE subject = ?)
`
func (q *Queries) RemoveUserRoles(ctx context.Context, subject string) error {
_, err := q.db.ExecContext(ctx, removeUserRoles, subject)
return err
}
const updateUserToken = `-- name: UpdateUserToken :exec
UPDATE users
SET access_token = ?,
refresh_token=?,
token_expiry = ?
WHERE subject = ?
`
type UpdateUserTokenParams struct {
AccessToken sql.NullString `json:"access_token"`
RefreshToken sql.NullString `json:"refresh_token"`
TokenExpiry sql.NullTime `json:"token_expiry"`
Subject string `json:"subject"`
}
func (q *Queries) UpdateUserToken(ctx context.Context, arg UpdateUserTokenParams) error {
_, err := q.db.ExecContext(ctx, updateUserToken,
arg.AccessToken,
arg.RefreshToken,
arg.TokenExpiry,
arg.Subject,
)
return err
}
const userEmailExists = `-- name: UserEmailExists :one const userEmailExists = `-- name: UserEmailExists :one
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists
` `
@ -143,3 +278,14 @@ func (q *Queries) UserEmailExists(ctx context.Context, email string) (bool, erro
err := row.Scan(&email_exists) err := row.Scan(&email_exists)
return email_exists, err return email_exists, err
} }
const verifyUserEmail = `-- name: VerifyUserEmail :exec
UPDATE users
SET email_verified=1
WHERE subject = ?
`
func (q *Queries) VerifyUserEmail(ctx context.Context, subject string) error {
_, err := q.db.ExecContext(ctx, verifyUserEmail, subject)
return err
}

View File

@ -21,9 +21,21 @@ CREATE TABLE users
zone TEXT NOT NULL DEFAULT 'UTC', zone TEXT NOT NULL DEFAULT 'UTC',
locale TEXT NOT NULL DEFAULT 'en-US', locale TEXT NOT NULL DEFAULT 'en-US',
login TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
auth_type INTEGER NOT NULL, auth_type INTEGER NOT NULL,
auth_namespace TEXT NOT NULL, auth_namespace TEXT NOT NULL,
auth_user TEXT NOT NULL auth_user TEXT NOT NULL,
access_token TEXT NULL DEFAULT NULL,
refresh_token TEXT NULL DEFAULT NULL,
token_expiry DATETIME NULL DEFAULT NULL,
otp_secret TEXT NOT NULL DEFAULT '',
otp_digits INTEGER NOT NULL DEFAULT 0,
to_delete BOOLEAN NOT NULL DEFAULT 0
); );
CREATE INDEX users_subject ON users (subject); CREATE INDEX users_subject ON users (subject);
@ -39,21 +51,12 @@ CREATE TABLE users_roles
role_id INTEGER NOT NULL, role_id INTEGER NOT NULL,
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
FOREIGN KEY (role_id) REFERENCES roles (id), FOREIGN KEY (role_id) REFERENCES roles (id) ON DELETE RESTRICT,
FOREIGN KEY (user_id) REFERENCES users (id), FOREIGN KEY (user_id) REFERENCES users (id),
CONSTRAINT user_role UNIQUE (role_id, user_id) CONSTRAINT user_role UNIQUE (role_id, user_id)
); );
CREATE TABLE otp
(
subject INTEGER NOT NULL UNIQUE PRIMARY KEY,
secret TEXT NOT NULL,
digits INTEGER NOT NULL,
FOREIGN KEY (subject) REFERENCES users (subject)
);
CREATE TABLE client_store CREATE TABLE client_store
( (
subject TEXT NOT NULL UNIQUE PRIMARY KEY, subject TEXT NOT NULL UNIQUE PRIMARY KEY,

View File

@ -5,9 +5,12 @@
package database package database
import ( import (
"database/sql"
"time" "time"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/password" "github.com/1f349/lavender/password"
"github.com/hardfinhq/go-date"
) )
type ClientStore struct { type ClientStore struct {
@ -22,24 +25,6 @@ type ClientStore struct {
Active bool `json:"active"` Active bool `json:"active"`
} }
type Otp struct {
Subject int64 `json:"subject"`
Secret string `json:"secret"`
Digits int64 `json:"digits"`
}
type Profile struct {
Subject string `json:"subject"`
Name string `json:"name"`
Picture string `json:"picture"`
Website string `json:"website"`
Pronouns string `json:"pronouns"`
Birthdate interface{} `json:"birthdate"`
Zone string `json:"zone"`
Locale string `json:"locale"`
UpdatedAt time.Time `json:"updated_at"`
}
type Role struct { type Role struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Role string `json:"role"` Role string `json:"role"`
@ -49,11 +34,30 @@ type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Subject string `json:"subject"` Subject string `json:"subject"`
Password password.HashString `json:"password"` Password password.HashString `json:"password"`
ChangePassword bool `json:"change_password"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Registered time.Time `json:"registered"` Registered time.Time `json:"registered"`
Active bool `json:"active"` Active bool `json:"active"`
Name string `json:"name"`
Picture string `json:"picture"`
Website string `json:"website"`
Pronouns types.UserPronoun `json:"pronouns"`
Birthdate date.NullDate `json:"birthdate"`
Zone string `json:"zone"`
Locale types.UserLocale `json:"locale"`
Login string `json:"login"`
ProfileUrl string `json:"profile_url"`
AuthType types.AuthType `json:"auth_type"`
AuthNamespace string `json:"auth_namespace"`
AuthUser string `json:"auth_user"`
AccessToken sql.NullString `json:"access_token"`
RefreshToken sql.NullString `json:"refresh_token"`
TokenExpiry sql.NullTime `json:"token_expiry"`
OtpSecret string `json:"otp_secret"`
OtpDigits int64 `json:"otp_digits"`
ToDelete bool `json:"to_delete"`
} }
type UsersRole struct { type UsersRole struct {

View File

@ -10,31 +10,32 @@ import (
) )
const deleteOtp = `-- name: DeleteOtp :exec const deleteOtp = `-- name: DeleteOtp :exec
DELETE UPDATE users
FROM otp SET otp_secret='',
WHERE otp.subject = ? otp_digits=0
WHERE subject = ?
` `
func (q *Queries) DeleteOtp(ctx context.Context, subject int64) error { func (q *Queries) DeleteOtp(ctx context.Context, subject string) error {
_, err := q.db.ExecContext(ctx, deleteOtp, subject) _, err := q.db.ExecContext(ctx, deleteOtp, subject)
return err return err
} }
const getOtp = `-- name: GetOtp :one const getOtp = `-- name: GetOtp :one
SELECT secret, digits SELECT otp_secret, otp_digits
FROM otp FROM users
WHERE subject = ? WHERE subject = ?
` `
type GetOtpRow struct { type GetOtpRow struct {
Secret string `json:"secret"` OtpSecret string `json:"otp_secret"`
Digits int64 `json:"digits"` OtpDigits int64 `json:"otp_digits"`
} }
func (q *Queries) GetOtp(ctx context.Context, subject int64) (GetOtpRow, error) { func (q *Queries) GetOtp(ctx context.Context, subject string) (GetOtpRow, error) {
row := q.db.QueryRowContext(ctx, getOtp, subject) row := q.db.QueryRowContext(ctx, getOtp, subject)
var i GetOtpRow var i GetOtpRow
err := row.Scan(&i.Secret, &i.Digits) err := row.Scan(&i.OtpSecret, &i.OtpDigits)
return i, err return i, err
} }
@ -52,10 +53,13 @@ func (q *Queries) GetUserEmail(ctx context.Context, subject string) (string, err
} }
const hasOtp = `-- name: HasOtp :one const hasOtp = `-- name: HasOtp :one
SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp SELECT CAST(1 AS BOOLEAN) AS hasOtp
FROM users
WHERE subject = ?
AND otp_secret != ''
` `
func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) { func (q *Queries) HasOtp(ctx context.Context, subject string) (bool, error) {
row := q.db.QueryRowContext(ctx, hasOtp, subject) row := q.db.QueryRowContext(ctx, hasOtp, subject)
var hasotp bool var hasotp bool
err := row.Scan(&hasotp) err := row.Scan(&hasotp)
@ -63,19 +67,19 @@ func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) {
} }
const setOtp = `-- name: SetOtp :exec const setOtp = `-- name: SetOtp :exec
INSERT OR UPDATE users
REPLACE SET otp_secret = ?,
INTO otp (subject, secret, digits) otp_digits=?
VALUES (?, ?, ?) WHERE subject = ?
` `
type SetOtpParams struct { type SetOtpParams struct {
Subject int64 `json:"subject"` OtpSecret string `json:"otp_secret"`
Secret string `json:"secret"` OtpDigits int64 `json:"otp_digits"`
Digits int64 `json:"digits"` Subject string `json:"subject"`
} }
func (q *Queries) SetOtp(ctx context.Context, arg SetOtpParams) error { func (q *Queries) SetOtp(ctx context.Context, arg SetOtpParams) error {
_, err := q.db.ExecContext(ctx, setOtp, arg.Subject, arg.Secret, arg.Digits) _, err := q.db.ExecContext(ctx, setOtp, arg.OtpSecret, arg.OtpDigits, arg.Subject)
return err return err
} }

View File

@ -2,22 +2,22 @@ package database
import ( import (
"context" "context"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/password" "github.com/1f349/lavender/password"
"github.com/google/uuid" "github.com/google/uuid"
"time" "time"
) )
type AddUserParams struct { type AddLocalUserParams struct {
Name string `json:"name"`
Subject string `json:"subject"`
Password string `json:"password"` Password string `json:"password"`
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
UpdatedAt time.Time `json:"updated_at"` Name string `json:"name"`
Active bool `json:"active"` Username string `json:"username"`
ChangePassword bool `json:"change_password"`
} }
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) { func (q *Queries) AddLocalUser(ctx context.Context, arg AddLocalUserParams) (string, error) {
pwHash, err := password.HashPassword(arg.Password) pwHash, err := password.HashPassword(arg.Password)
if err != nil { if err != nil {
return "", err return "", err
@ -31,6 +31,40 @@ func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error
UpdatedAt: n, UpdatedAt: n,
Registered: n, Registered: n,
Active: true, Active: true,
Name: arg.Name,
Login: arg.Username,
ChangePassword: arg.ChangePassword,
AuthType: types.AuthTypeLocal,
AuthNamespace: "",
AuthUser: arg.Username,
}
return a.Subject, q.addUser(ctx, a)
}
type AddOAuthUserParams struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
Username string `json:"username"`
AuthNamespace string `json:"auth_namespace"`
AuthUser string `json:"auth_user"`
}
func (q *Queries) AddOAuthUser(ctx context.Context, arg AddOAuthUserParams) (string, error) {
n := time.Now()
a := addUserParams{
Subject: uuid.NewString(),
Email: arg.Email,
EmailVerified: arg.EmailVerified,
UpdatedAt: n,
Registered: n,
Active: true,
Name: arg.Name,
Login: arg.Username,
ChangePassword: false,
AuthType: types.AuthTypeOauth2,
AuthNamespace: arg.AuthNamespace,
AuthUser: arg.AuthUser,
} }
return a.Subject, q.addUser(ctx, a) return a.Subject, q.addUser(ctx, a)
} }

View File

@ -8,17 +8,38 @@ package database
import ( import (
"context" "context"
"time" "time"
"github.com/1f349/lavender/database/types"
"github.com/hardfinhq/go-date"
) )
const getProfile = `-- name: GetProfile :one const getProfile = `-- name: GetProfile :one
SELECT profiles.subject, profiles.name, profiles.picture, profiles.website, profiles.pronouns, profiles.birthdate, profiles.zone, profiles.locale, profiles.updated_at SELECT subject,
FROM profiles name,
picture,
website,
pronouns,
birthdate,
zone,
locale
FROM users
WHERE subject = ? WHERE subject = ?
` `
func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, error) { type GetProfileRow struct {
Subject string `json:"subject"`
Name string `json:"name"`
Picture string `json:"picture"`
Website string `json:"website"`
Pronouns types.UserPronoun `json:"pronouns"`
Birthdate date.NullDate `json:"birthdate"`
Zone string `json:"zone"`
Locale types.UserLocale `json:"locale"`
}
func (q *Queries) GetProfile(ctx context.Context, subject string) (GetProfileRow, error) {
row := q.db.QueryRowContext(ctx, getProfile, subject) row := q.db.QueryRowContext(ctx, getProfile, subject)
var i Profile var i GetProfileRow
err := row.Scan( err := row.Scan(
&i.Subject, &i.Subject,
&i.Name, &i.Name,
@ -28,13 +49,12 @@ func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, erro
&i.Birthdate, &i.Birthdate,
&i.Zone, &i.Zone,
&i.Locale, &i.Locale,
&i.UpdatedAt,
) )
return i, err return i, err
} }
const modifyProfile = `-- name: ModifyProfile :exec const modifyProfile = `-- name: ModifyProfile :exec
UPDATE profiles UPDATE users
SET name = ?, SET name = ?,
picture = ?, picture = ?,
website = ?, website = ?,
@ -50,10 +70,10 @@ type ModifyProfileParams struct {
Name string `json:"name"` Name string `json:"name"`
Picture string `json:"picture"` Picture string `json:"picture"`
Website string `json:"website"` Website string `json:"website"`
Pronouns string `json:"pronouns"` Pronouns types.UserPronoun `json:"pronouns"`
Birthdate interface{} `json:"birthdate"` Birthdate date.NullDate `json:"birthdate"`
Zone string `json:"zone"` Zone string `json:"zone"`
Locale string `json:"locale"` Locale types.UserLocale `json:"locale"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Subject string `json:"subject"` Subject string `json:"subject"`
} }

View File

@ -15,7 +15,7 @@ SELECT subject,
active active
FROM client_store FROM client_store
WHERE owner_subject = ? WHERE owner_subject = ?
OR ? = 1 OR CAST(? AS BOOLEAN) = 1
LIMIT 25 OFFSET ?; LIMIT 25 OFFSET ?;
-- name: InsertClientApp :exec -- name: InsertClientApp :exec

View File

@ -5,11 +5,10 @@ SELECT users.subject,
website, website,
email, email,
email_verified, email_verified,
users.updated_at as user_updated_at, updated_at,
p.updated_at as profile_updated_at,
active active
FROM users FROM users
INNER JOIN main.profiles p on users.subject = p.subject --INNER JOIN main.profiles p on users.subject = p.subject
LIMIT 50 OFFSET ?; LIMIT 50 OFFSET ?;
-- name: GetUsersRoles :many -- name: GetUsersRoles :many
@ -24,5 +23,54 @@ UPDATE users
SET active = cast(? as boolean) SET active = cast(? as boolean)
WHERE subject = ?; WHERE subject = ?;
-- name: VerifyUserEmail :exec
UPDATE users
SET email_verified=1
WHERE subject = ?;
-- name: UserEmailExists :one -- name: UserEmailExists :one
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists; SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists;
-- name: ModifyUserEmail :exec
UPDATE users
SET email = ?,
email_verified=?
WHERE subject = ?;
-- name: ModifyUserAuth :exec
UPDATE users
SET auth_type = ?,
auth_namespace=?,
auth_user = ?
WHERE subject = ?;
-- name: ModifyUserRemoteLogin :exec
UPDATE users
SET login = ?,
profile_url = ?
WHERE subject = ?;
-- name: UpdateUserToken :exec
UPDATE users
SET access_token = ?,
refresh_token=?,
token_expiry = ?
WHERE subject = ?;
-- name: GetUserToken :one
SELECT access_token, refresh_token, token_expiry
FROM users
WHERE subject = ?;
-- name: RemoveUserRoles :exec
DELETE
FROM users_roles
WHERE user_id IN (SELECT id
FROM users
WHERE subject = ?);
-- name: AddUserRole :exec
INSERT INTO users_roles(role_id, user_id)
SELECT ?, users.id
FROM users
WHERE subject = ?;

View File

@ -1,21 +1,25 @@
-- name: SetOtp :exec -- name: SetOtp :exec
INSERT OR UPDATE users
REPLACE SET otp_secret = ?,
INTO otp (subject, secret, digits) otp_digits=?
VALUES (?, ?, ?); WHERE subject = ?;
-- name: DeleteOtp :exec -- name: DeleteOtp :exec
DELETE UPDATE users
FROM otp SET otp_secret='',
WHERE otp.subject = ?; otp_digits=0
WHERE subject = ?;
-- name: GetOtp :one -- name: GetOtp :one
SELECT secret, digits SELECT otp_secret, otp_digits
FROM otp FROM users
WHERE subject = ?; WHERE subject = ?;
-- name: HasOtp :one -- name: HasOtp :one
SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp; SELECT CAST(1 AS BOOLEAN) AS hasOtp
FROM users
WHERE subject = ?
AND otp_secret != '';
-- name: GetUserEmail :one -- name: GetUserEmail :one
SELECT email SELECT email

View File

@ -1,10 +1,17 @@
-- name: GetProfile :one -- name: GetProfile :one
SELECT profiles.* SELECT subject,
FROM profiles name,
picture,
website,
pronouns,
birthdate,
zone,
locale
FROM users
WHERE subject = ?; WHERE subject = ?;
-- name: ModifyProfile :exec -- name: ModifyProfile :exec
UPDATE profiles UPDATE users
SET name = ?, SET name = ?,
picture = ?, picture = ?,
website = ?, website = ?,

View File

@ -0,0 +1,8 @@
-- name: AddRole :execlastid
INSERT OR IGNORE INTO roles(role)
VALUES (?);
-- name: RemoveRole :exec
DELETE
FROM roles
WHERE role = ?;

View File

@ -3,15 +3,11 @@ SELECT count(subject) > 0 AS hasUser
FROM users; FROM users;
-- name: addUser :exec -- name: addUser :exec
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user)
VALUES (?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
-- name: addOAuthUser :exec
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active)
VALUES (?, ?, ?, ?, ?, ?, ?);
-- name: checkLogin :one -- name: checkLogin :one
SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified
FROM users FROM users
WHERE users.subject = ? WHERE users.subject = ?
LIMIT 1; LIMIT 1;
@ -48,3 +44,9 @@ SET password = ?,
updated_at=? updated_at=?
WHERE subject = ? WHERE subject = ?
AND password = ?; AND password = ?;
-- name: FlagUserAsDeleted :exec
UPDATE users
SET active= false,
to_delete = true
WHERE subject = ?;

34
database/roles.sql.go Normal file
View File

@ -0,0 +1,34 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: roles.sql
package database
import (
"context"
)
const addRole = `-- name: AddRole :execlastid
INSERT OR IGNORE INTO roles(role)
VALUES (?)
`
func (q *Queries) AddRole(ctx context.Context, role string) (int64, error) {
result, err := q.db.ExecContext(ctx, addRole, role)
if err != nil {
return 0, err
}
return result.LastInsertId()
}
const removeRole = `-- name: RemoveRole :exec
DELETE
FROM roles
WHERE role = ?
`
func (q *Queries) RemoveRole(ctx context.Context, role string) error {
_, err := q.db.ExecContext(ctx, removeRole, role)
return err
}

26
database/tx.go Normal file
View File

@ -0,0 +1,26 @@
package database
import (
"context"
"database/sql"
"errors"
)
var errCannotOpenTransactionWithoutSqlDB = errors.New("cannot open transaction without sql.DB")
func (q *Queries) UseTx(ctx context.Context, cb func(tx *Queries) error) error {
sqlDB, ok := q.db.(*sql.DB)
if !ok {
panic(errCannotOpenTransactionWithoutSqlDB)
}
tx, err := sqlDB.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
err = cb(q.WithTx(tx))
if err != nil {
return err
}
return tx.Commit()
}

View File

@ -3,7 +3,7 @@ package types
type AuthType byte type AuthType byte
const ( const (
AuthTypeBase AuthType = iota AuthTypeLocal AuthType = iota
AuthTypeOauth2 AuthTypeOauth2
) )

View File

@ -9,11 +9,24 @@ import (
"context" "context"
"time" "time"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/password" "github.com/1f349/lavender/password"
) )
const flagUserAsDeleted = `-- name: FlagUserAsDeleted :exec
UPDATE users
SET active= false,
to_delete = true
WHERE subject = ?
`
func (q *Queries) FlagUserAsDeleted(ctx context.Context, subject string) error {
_, err := q.db.ExecContext(ctx, flagUserAsDeleted, subject)
return err
}
const getUser = `-- name: GetUser :one const getUser = `-- name: GetUser :one
SELECT id, subject, password, email, email_verified, updated_at, registered, active SELECT id, subject, password, change_password, email, email_verified, updated_at, registered, active, name, picture, website, pronouns, birthdate, zone, locale, login, profile_url, auth_type, auth_namespace, auth_user, access_token, refresh_token, token_expiry, otp_secret, otp_digits, to_delete
FROM users FROM users
WHERE subject = ? WHERE subject = ?
LIMIT 1 LIMIT 1
@ -26,11 +39,30 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
&i.ID, &i.ID,
&i.Subject, &i.Subject,
&i.Password, &i.Password,
&i.ChangePassword,
&i.Email, &i.Email,
&i.EmailVerified, &i.EmailVerified,
&i.UpdatedAt, &i.UpdatedAt,
&i.Registered, &i.Registered,
&i.Active, &i.Active,
&i.Name,
&i.Picture,
&i.Website,
&i.Pronouns,
&i.Birthdate,
&i.Zone,
&i.Locale,
&i.Login,
&i.ProfileUrl,
&i.AuthType,
&i.AuthNamespace,
&i.AuthUser,
&i.AccessToken,
&i.RefreshToken,
&i.TokenExpiry,
&i.OtpSecret,
&i.OtpDigits,
&i.ToDelete,
) )
return i, err return i, err
} }
@ -98,8 +130,8 @@ func (q *Queries) UserHasRole(ctx context.Context, arg UserHasRoleParams) error
} }
const addUser = `-- name: addUser :exec const addUser = `-- name: addUser :exec
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
type addUserParams struct { type addUserParams struct {
@ -110,6 +142,12 @@ type addUserParams struct {
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Registered time.Time `json:"registered"` Registered time.Time `json:"registered"`
Active bool `json:"active"` Active bool `json:"active"`
Name string `json:"name"`
Login string `json:"login"`
ChangePassword bool `json:"change_password"`
AuthType types.AuthType `json:"auth_type"`
AuthNamespace string `json:"auth_namespace"`
AuthUser string `json:"auth_user"`
} }
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
@ -121,6 +159,12 @@ func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
arg.UpdatedAt, arg.UpdatedAt,
arg.Registered, arg.Registered,
arg.Active, arg.Active,
arg.Name,
arg.Login,
arg.ChangePassword,
arg.AuthType,
arg.AuthNamespace,
arg.AuthUser,
) )
return err return err
} }
@ -151,7 +195,7 @@ func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPassword
} }
const checkLogin = `-- name: checkLogin :one const checkLogin = `-- name: checkLogin :one
SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified
FROM users FROM users
WHERE users.subject = ? WHERE users.subject = ?
LIMIT 1 LIMIT 1

10
go.mod
View File

@ -6,12 +6,10 @@ require (
github.com/1f349/cache v0.0.3 github.com/1f349/cache v0.0.3
github.com/1f349/mjwt v0.4.1 github.com/1f349/mjwt v0.4.1
github.com/1f349/overlapfs v0.0.1 github.com/1f349/overlapfs v0.0.1
github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 github.com/1f349/simplemail v0.0.5
github.com/charmbracelet/log v0.4.0 github.com/charmbracelet/log v0.4.0
github.com/cloudflare/tableflip v1.2.3 github.com/cloudflare/tableflip v1.2.3
github.com/emersion/go-message v0.18.1 github.com/emersion/go-message v0.18.1
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43
github.com/emersion/go-smtp v0.21.3
github.com/go-oauth2/oauth2/v4 v4.5.2 github.com/go-oauth2/oauth2/v4 v4.5.2
github.com/golang-jwt/jwt/v4 v4.5.0 github.com/golang-jwt/jwt/v4 v4.5.0
github.com/golang-migrate/migrate/v4 v4.17.1 github.com/golang-migrate/migrate/v4 v4.17.1
@ -21,10 +19,13 @@ require (
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.22 github.com/mattn/go-sqlite3 v1.14.22
github.com/mrmelon54/pronouns v1.0.3 github.com/mrmelon54/pronouns v1.0.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/spf13/afero v1.11.0 github.com/spf13/afero v1.11.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/xlzd/gotp v0.1.0
golang.org/x/crypto v0.26.0 golang.org/x/crypto v0.26.0
golang.org/x/oauth2 v0.22.0 golang.org/x/oauth2 v0.22.0
golang.org/x/sync v0.8.0
golang.org/x/text v0.17.0 golang.org/x/text v0.17.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
@ -36,6 +37,8 @@ require (
github.com/charmbracelet/lipgloss v0.12.1 // indirect github.com/charmbracelet/lipgloss v0.12.1 // indirect
github.com/charmbracelet/x/ansi v0.2.1 // indirect github.com/charmbracelet/x/ansi v0.2.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 // indirect
github.com/emersion/go-smtp v0.21.3 // indirect
github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
@ -63,6 +66,5 @@ require (
go.uber.org/atomic v1.11.0 // indirect go.uber.org/atomic v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
golang.org/x/net v0.28.0 // indirect golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.24.0 // indirect golang.org/x/sys v0.24.0 // indirect
) )

8
go.sum
View File

@ -7,8 +7,8 @@ github.com/1f349/overlapfs v0.0.1 h1:LAxBolrXFAgU0yqZtXg/C/aaPq3eoQSPpBc49BHuTp0
github.com/1f349/overlapfs v0.0.1/go.mod h1:I6aItQycr7nrzplmfNXp/QF9tTmKRSgY3fXmu/7Ky2o= github.com/1f349/overlapfs v0.0.1/go.mod h1:I6aItQycr7nrzplmfNXp/QF9tTmKRSgY3fXmu/7Ky2o=
github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U= github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U=
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg= github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 h1:jPg+0bgKD5kY7yQtRZqeba+BGKFE51evGvwewZwa7Xc= github.com/1f349/simplemail v0.0.5 h1:cr+8pdWhFE/+XVSO7ZTjntySbmIbTqmDy2SR9cHAPLE=
github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63/go.mod h1:1zFQhcbgiyPSWHVMp0cXJjmd6FhasP5bf5tWS4ZK61A= github.com/1f349/simplemail v0.0.5/go.mod h1:ppAIqkvVkI6L99EefbR5NgOjpePNK/RKgeoehj5A+kU=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
@ -146,6 +146,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
@ -199,6 +201,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/xlzd/gotp v0.1.0 h1:37blvlKCh38s+fkem+fFh7sMnceltoIEBYTVXyoa5Po=
github.com/xlzd/gotp v0.1.0/go.mod h1:ndLJ3JKzi3xLmUProq4LLxCuECL93dG9WASNLpHz8qg=
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=

View File

@ -25,6 +25,7 @@ func NewManager(services map[string]SsoConfig) (*Manager, error) {
} }
// save by namespace // save by namespace
conf.Namespace = namespace
l.m[namespace] = conf l.m[namespace] = conf
} }
return l, nil return l, nil

View File

@ -1,26 +0,0 @@
package mail
import (
"encoding/json"
"github.com/emersion/go-message/mail"
)
type FromAddress struct {
*mail.Address
}
var _ json.Unmarshaler = &FromAddress{}
func (f *FromAddress) UnmarshalJSON(b []byte) error {
var a string
err := json.Unmarshal(b, &a)
if err != nil {
return err
}
address, err := mail.ParseAddress(a)
if err != nil {
return err
}
f.Address = address
return nil
}

View File

@ -1,96 +1,48 @@
package mail package mail
import ( import (
"bytes" "embed"
"errors"
"fmt"
"github.com/1f349/overlapfs"
"github.com/1f349/simplemail"
"github.com/emersion/go-message/mail" "github.com/emersion/go-message/mail"
"github.com/emersion/go-sasl" "io/fs"
"github.com/emersion/go-smtp" "os"
"io" "path/filepath"
"net"
"time"
) )
//go:embed templates/*.go.html templates/*.go.txt
var embeddedTemplates embed.FS
type Mail struct { type Mail struct {
Name string `json:"name"` mail *simplemail.SimpleMail
Tls bool `json:"tls"` name string
Server string `json:"server"`
From FromAddress `json:"from"`
Username string `json:"username"`
Password string `json:"password"`
} }
func (m *Mail) loginInfo() sasl.Client { func New(sender *simplemail.Mail, wd, name string) (*Mail, error) {
return sasl.NewPlainClient("", m.Username, m.Password) var o fs.FS = embeddedTemplates
o, _ = fs.Sub(o, "templates")
if wd != "" {
mailDir := filepath.Join(wd, "mail-templates")
err := os.Mkdir(mailDir, os.ModePerm)
if err == nil || errors.Is(err, os.ErrExist) {
wdFs := os.DirFS(mailDir)
o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs}
}
} }
func (m *Mail) mailCall(to []string, r io.Reader) error { simpleMail, err := simplemail.New(sender, o)
host, _, err := net.SplitHostPort(m.Server) return &Mail{
if err != nil { mail: simpleMail,
return err name: name,
} }, err
if m.Tls {
return smtp.SendMailTLS(m.Server, m.loginInfo(), m.From.String(), to, r)
}
if host == "localhost" || host == "127.0.0.1" {
// internals of smtp.SendMail without STARTTLS for localhost testing
dial, err := smtp.Dial(m.Server)
if err != nil {
return err
}
err = dial.Auth(m.loginInfo())
if err != nil {
return err
}
return dial.SendMail(m.From.String(), to, r)
}
return smtp.SendMail(m.Server, m.loginInfo(), m.From.String(), to, r)
} }
func (m *Mail) SendMail(subject string, to []*mail.Address, htmlBody, textBody io.Reader) error { func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error {
// generate the email in this template return m.mail.Send(templateName, fmt.Sprintf("%s - %s", subject, m.name), to, map[string]any{
buf := new(bytes.Buffer) "ServiceName": m.name,
"Name": nameOfUser,
// setup mail headers "Data": data,
var h mail.Header })
h.SetDate(time.Now())
h.SetSubject(subject)
h.SetAddressList("From", []*mail.Address{m.From.Address})
h.SetAddressList("To", to)
h.Set("Content-Type", "multipart/alternative")
// setup html and text alternative headers
var hHtml, hTxt mail.InlineHeader
hHtml.Set("Content-Type", "text/html; charset=utf-8")
hTxt.Set("Content-Type", "text/plain; charset=utf-8")
createWriter, err := mail.CreateWriter(buf, h)
if err != nil {
return err
}
inline, err := createWriter.CreateInline()
if err != nil {
return err
}
partHtml, err := inline.CreatePart(hHtml)
if err != nil {
return err
}
if _, err := io.Copy(partHtml, htmlBody); err != nil {
return err
}
partTxt, err := inline.CreatePart(hTxt)
if err != nil {
return err
}
if _, err := io.Copy(partTxt, textBody); err != nil {
return err
}
// convert all to addresses to strings
toStr := make([]string, len(to))
for i := range toStr {
toStr[i] = to[i].String()
}
return m.mailCall(toStr, buf)
} }

View File

@ -1,18 +0,0 @@
package mail
import (
"bytes"
"fmt"
"github.com/1f349/lavender/mail/templates"
"github.com/emersion/go-message/mail"
)
func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error {
var bufHtml, bufTxt bytes.Buffer
templates.RenderMailTemplate(&bufHtml, &bufTxt, templateName, map[string]any{
"ServiceName": m.Name,
"Name": nameOfUser,
"Data": data,
})
return m.SendMail(fmt.Sprintf("%s - %s", subject, m.Name), []*mail.Address{to}, &bufHtml, &bufTxt)
}

View File

@ -1,55 +0,0 @@
package templates
import (
"embed"
"errors"
"github.com/1f349/overlapfs"
"github.com/1f349/tulip/logger"
htmlTemplate "html/template"
"io"
"io/fs"
"os"
"path/filepath"
"sync"
textTemplate "text/template"
)
var (
//go:embed *.go.html *.go.txt
embeddedTemplates embed.FS
mailHtmlTemplates *htmlTemplate.Template
mailTextTemplates *textTemplate.Template
loadOnce sync.Once
)
func LoadMailTemplates(wd string) (err error) {
loadOnce.Do(func() {
var o fs.FS = embeddedTemplates
if wd != "" {
mailDir := filepath.Join(wd, "mail-templates")
err = os.Mkdir(mailDir, os.ModePerm)
if err != nil && !errors.Is(err, os.ErrExist) {
return
}
wdFs := os.DirFS(mailDir)
o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs}
}
mailHtmlTemplates, err = htmlTemplate.New("mail").ParseFS(o, "*.go.html")
if err != nil {
return
}
mailTextTemplates, err = textTemplate.New("mail").ParseFS(o, "*.go.txt")
})
return
}
func RenderMailTemplate(wrHtml, wrTxt io.Writer, name string, data any) {
err := mailHtmlTemplates.ExecuteTemplate(wrHtml, name+".go.html", data)
if err != nil {
logger.Logger.Warn("Failed to render mail html", "name", name, "err", err)
}
err = mailTextTemplates.ExecuteTemplate(wrTxt, name+".go.txt", data)
if err != nil {
logger.Logger.Warn("Failed to render mail text", "name", name, "err", err)
}
}

View File

@ -59,7 +59,7 @@ func (h *httpServer) RequireAdminAuthentication(next UserHandler) httprouter.Han
} }
func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle { func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle {
return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { return h.OptionalAuthentication(false, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
if auth.IsGuest() { if auth.IsGuest() {
redirectUrl := PrepareRedirectUrl("/login", req.URL) redirectUrl := PrepareRedirectUrl("/login", req.URL)
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound) http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
@ -69,16 +69,20 @@ func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle {
}) })
} }
func (h *httpServer) OptionalAuthentication(next UserHandler) httprouter.Handle { func (h *httpServer) OptionalAuthentication(flowPart bool, next UserHandler) httprouter.Handle {
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
authUser, err := h.internalAuthenticationHandler(rw, req) authData, err := h.internalAuthenticationHandler(rw, req)
if err != nil { if err != nil {
if !errors.Is(err, ErrAuthHttpError) { if !errors.Is(err, ErrAuthHttpError) {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
} }
return return
} }
next(rw, req, params, authUser) if n := authData.NextFlowUrl(req.URL); n != nil && !flowPart {
http.Redirect(rw, req, n.String(), http.StatusFound)
return
}
next(rw, req, params, authData)
} }
} }

73
server/auth_test.go Normal file
View File

@ -0,0 +1,73 @@
package server
import (
"context"
"github.com/1f349/mjwt"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestUserAuth_NextFlowUrl(t *testing.T) {
u := UserAuth{NeedOtp: true}
assert.Equal(t, url.URL{Path: "/login/otp"}, *u.NextFlowUrl(&url.URL{}))
assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello"}))
assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()}))
u.NeedOtp = false
assert.Nil(t, u.NextFlowUrl(&url.URL{}))
}
func TestUserAuth_IsGuest(t *testing.T) {
var u UserAuth
assert.True(t, u.IsGuest())
u.Subject = uuid.NewString()
assert.False(t, u.IsGuest())
}
type fakeSessionStore struct {
m map[string]any
saveFunc func(map[string]any) error
}
func (f *fakeSessionStore) Context() context.Context { return context.Background() }
func (f *fakeSessionStore) SessionID() string { return "fakeSessionStore" }
func (f *fakeSessionStore) Set(key string, value interface{}) { f.m[key] = value }
func (f *fakeSessionStore) Get(key string) (a interface{}, ok bool) {
if a, ok = f.m[key]; false {
}
return
}
func TestRequireAuthentication(t *testing.T) {
}
func TestOptionalAuthentication(t *testing.T) {
jwtIssuer, err := mjwt.NewIssuer("TestIssuer", uuid.NewString(), jwt.SigningMethodRS512)
h := &httpServer{signingKey: jwtIssuer}
rec := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello", nil)
assert.NoError(t, err)
auth, err := h.internalAuthenticationHandler(rec, req)
assert.NoError(t, err)
assert.True(t, auth.IsGuest())
auth.Subject = "567"
}
func TestPrepareRedirectUrl(t *testing.T) {
assert.Equal(t, url.URL{Path: "/hello"}, *PrepareRedirectUrl("/hello", &url.URL{}))
assert.Equal(t, url.URL{Path: "/world"}, *PrepareRedirectUrl("/world", &url.URL{}))
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello"}))
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()}))
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()}))
assert.Equal(t, url.URL{Path: "/hello", RawQuery: "z=y"}, *PrepareRedirectUrl("/hello?z=y", &url.URL{}))
assert.Equal(t, url.URL{Path: "/world", RawQuery: "z=y"}, *PrepareRedirectUrl("/world?z=y", &url.URL{}))
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello"}))
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()}))
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()}))
}

87
server/edit.go Normal file
View File

@ -0,0 +1,87 @@
package server
import (
"fmt"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/lists"
"github.com/1f349/lavender/pages"
"github.com/google/uuid"
"github.com/julienschmidt/httprouter"
"net/http"
"time"
)
func (h *httpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
var user database.User
if h.DbTx(rw, func(tx *database.Queries) error {
var err error
user, err = tx.GetUser(req.Context(), auth.Subject)
if err != nil {
return fmt.Errorf("failed to read user data: %w", err)
}
return nil
}) {
return
}
lNonce := uuid.NewString()
http.SetCookie(rw, &http.Cookie{
Name: "tulip-nonce",
Value: lNonce,
Path: "/",
Expires: time.Now().Add(10 * time.Minute),
Secure: true,
SameSite: http.SameSiteLaxMode,
})
pages.RenderPageTemplate(rw, "edit", map[string]any{
"ServiceName": h.conf.ServiceName,
"User": user,
"Nonce": lNonce,
"FieldPronoun": user.Pronouns.String(),
"ListZoneInfo": lists.ListZoneInfo(),
"ListLocale": lists.ListLocale(),
})
}
func (h *httpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
if req.ParseForm() != nil {
rw.WriteHeader(http.StatusBadRequest)
_, _ = rw.Write([]byte("400 Bad Request\n"))
return
}
var patch database.ProfilePatch
errs := patch.ParseFromForm(req.Form)
if len(errs) > 0 {
rw.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintln(rw, "<!DOCTYPE html>\n<html>\n<body>")
_, _ = fmt.Fprintln(rw, "<p>400 Bad Request: Failed to parse form data, press the back button in your browser, check your inputs and try again.</p>")
_, _ = fmt.Fprintln(rw, "<ul>")
for _, i := range errs {
_, _ = fmt.Fprintf(rw, " <li>%s</li>\n", i)
}
_, _ = fmt.Fprintln(rw, "</ul>")
_, _ = fmt.Fprintln(rw, "</body>\n</html>")
return
}
m := database.ModifyProfileParams{
Name: patch.Name,
Picture: patch.Picture,
Website: patch.Website,
Pronouns: patch.Pronouns,
Birthdate: patch.Birthdate,
Zone: patch.Zone.String(),
Locale: patch.Locale,
UpdatedAt: time.Now(),
Subject: auth.Subject,
}
if h.DbTx(rw, func(tx *database.Queries) error {
if err := tx.ModifyProfile(req.Context(), m); err != nil {
return fmt.Errorf("failed to modify user info: %w", err)
}
return nil
}) {
return
}
http.Redirect(rw, req, "/edit", http.StatusFound)
}

View File

@ -42,4 +42,51 @@ func (h *httpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
"Nonce": lNonce, "Nonce": lNonce,
"IsAdmin": isAdmin, "IsAdmin": isAdmin,
}) })
// rw.Header().Set("Content-Type", "text/html")
// lNonce := uuid.NewString()
// http.SetCookie(rw, &http.Cookie{
// Name: "tulip-nonce",
// Value: lNonce,
// Path: "/",
// Expires: time.Now().Add(10 * time.Minute),
// Secure: true,
// SameSite: http.SameSiteLaxMode,
// })
//
// if auth.IsGuest() {
// pages.RenderPageTemplate(rw, "index-guest", map[string]any{
// "ServiceName": h.conf.ServiceName,
// })
// return
// }
//
// var userWithName string
// var userRole types.UserRole
// var hasTwoFactor bool
// if h.DbTx(rw, func(tx *database.Queries) (err error) {
// userWithName, err = tx.GetUserDisplayName(req.Context(), auth.Subject)
// if err != nil {
// return fmt.Errorf("failed to get user display name: %w", err)
// }
// hasTwoFactor, err = tx.HasOtp(req.Context(), auth.Subject)
// if err != nil {
// return fmt.Errorf("failed to get user two factor state: %w", err)
// }
// userRole, err = tx.GetUserRole(req.Context(), auth.Subject)
// if err != nil {
// return fmt.Errorf("failed to get user role: %w", err)
// }
// return
// }) {
// return
// }
// pages.RenderPageTemplate(rw, "index", map[string]any{
// "ServiceName": h.conf.ServiceName,
// "Auth": auth,
// "User": database.User{Subject: auth.Subject, Name: userWithName, Role: userRole},
// "Nonce": lNonce,
// "OtpEnabled": hasTwoFactor,
// "IsAdmin": userRole == types.RoleAdmin,
// })
} }

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
auth2 "github.com/1f349/lavender/auth" auth2 "github.com/1f349/lavender/auth"
"github.com/1f349/lavender/database" "github.com/1f349/lavender/database"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/pages" "github.com/1f349/lavender/pages"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
@ -15,13 +16,31 @@ import (
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/mrmelon54/pronouns"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/text/language"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
) )
// getUserLoginName finds the `login_name` query parameter within the `/authorize` redirect url
func getUserLoginName(req *http.Request) string {
q := req.URL.Query()
if !q.Has("redirect") {
return ""
}
originUrl, err := url.ParseRequestURI(q.Get("redirect"))
if err != nil {
return ""
}
if originUrl.Path != "/authorize" {
return ""
}
return originUrl.Query().Get("login_name")
}
func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
if !auth.IsGuest() { if !auth.IsGuest() {
h.SafeRedirect(rw, req) h.SafeRedirect(rw, req)
@ -131,41 +150,70 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK
} }
err = h.DbTxError(func(tx *database.Queries) error { err = h.DbTxError(func(tx *database.Queries) error {
jBytes, err := json.Marshal(sessionData.UserInfo) name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
_, err = tx.GetUser(req.Context(), sessionData.Subject)
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
if errors.Is(err, sql.ErrNoRows) {
_, err := tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{
Email: uEmail,
EmailVerified: uEmailVerified,
Name: name,
Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
AuthNamespace: sso.Namespace,
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
})
return err
}
err = tx.ModifyUserEmail(req.Context(), database.ModifyUserEmailParams{
Email: uEmail,
EmailVerified: uEmailVerified,
Subject: sessionData.Subject,
})
if err != nil { if err != nil {
return err return err
} }
_, err = tx.GetUser(req.Context(), sessionData.Subject)
if errors.Is(err, sql.ErrNoRows) { err = tx.ModifyUserAuth(req.Context(), database.ModifyUserAuthParams{
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") AuthType: types.AuthTypeOauth2,
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") AuthNamespace: sso.Namespace,
id, err := tx.AddUser(req.Context(), database.AddUserParams{ AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
Name: "",
Subject: sessionData.Subject, Subject: sessionData.Subject,
Password: "",
Email: uEmail,
EmailVerified: uEmailVerified,
UpdatedAt: time.Now(),
Active: true,
}) })
if err != nil {
return err return err
return tx.AddUser(req.Context(), database.AddUserParams{
Subject: sessionData.Subject,
Email: uEmail,
EmailVerified: uEmailVerified,
Roles: "",
Userinfo: string(jBytes),
UpdatedAt: time.Now(),
Active: true,
})
} }
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") err = tx.ModifyUserRemoteLogin(req.Context(), database.ModifyUserRemoteLoginParams{
return tx.UpdateUserInfo(req.Context(), database.UpdateUserInfoParams{ Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
Email: sessionData.Subject, ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"),
EmailVerified: uEmailVerified, Subject: sessionData.Subject,
Userinfo: string(jBytes), })
Subject: uEmail, if err != nil {
return err
}
pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns"))
if err != nil {
pronoun = pronouns.TheyThem
}
locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale"))
if err != nil {
locale = language.AmericanEnglish
}
return tx.ModifyProfile(req.Context(), database.ModifyProfileParams{
Name: name,
Picture: sessionData.UserInfo.GetStringOrEmpty("profile"),
Website: sessionData.UserInfo.GetStringOrEmpty("website"),
Pronouns: types.UserPronoun{Pronoun: pronoun},
Birthdate: sessionData.UserInfo.GetNullDate("birthdate"),
Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"),
Locale: types.UserLocale{Tag: locale},
UpdatedAt: time.Now(),
Subject: sessionData.Subject,
}) })
}) })
if err != nil { if err != nil {
@ -177,7 +225,7 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{ return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
AccessToken: sql.NullString{String: token.AccessToken, Valid: true}, AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true}, RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
Expiry: sql.NullTime{Time: token.Expiry, Valid: true}, TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true},
Subject: sessionData.Subject, Subject: sessionData.Subject,
}) })
}); err != nil { }); err != nil {
@ -208,6 +256,11 @@ func (l lavenderLoginRefresh) Valid() error { return l.RefreshTokenClaims.Valid(
func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" } func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" }
func (h *httpServer) setLoginDataCookie2(rw http.ResponseWriter, authData UserAuth) bool {
// TODO(melon): should probably merge there methods
return h.setLoginDataCookie(rw, authData, "")
}
func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool { func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool {
ps := auth.NewPermStorage() ps := auth.NewPermStorage()
accId := uuid.NewString() accId := uuid.NewString()
@ -286,13 +339,13 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
if err != nil { if err != nil {
return err return err
} }
if !token.AccessToken.Valid || !token.RefreshToken.Valid || !token.Expiry.Valid { if !token.AccessToken.Valid || !token.RefreshToken.Valid || !token.TokenExpiry.Valid {
return fmt.Errorf("invalid oauth token") return fmt.Errorf("invalid oauth token")
} }
oauthToken = &oauth2.Token{ oauthToken = &oauth2.Token{
AccessToken: token.AccessToken.String, AccessToken: token.AccessToken.String,
RefreshToken: token.RefreshToken.String, RefreshToken: token.RefreshToken.String,
Expiry: token.Expiry.Time, Expiry: token.TokenExpiry.Time,
} }
return nil return nil
}) })

25
server/logout.go Normal file
View File

@ -0,0 +1,25 @@
package server
import (
"github.com/julienschmidt/httprouter"
"net/http"
)
func (h *httpServer) logoutPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ UserAuth) {
http.SetCookie(rw, &http.Cookie{
Name: "lavender-login-access",
Path: "/",
MaxAge: -1,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(rw, &http.Cookie{
Name: "lavender-login-refresh",
Path: "/",
MaxAge: -1,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
http.Redirect(rw, req, "/", http.StatusFound)
}

123
server/mail.go Normal file
View File

@ -0,0 +1,123 @@
package server
import (
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/pages"
"github.com/emersion/go-message/mail"
"github.com/julienschmidt/httprouter"
"net/http"
)
func (h *httpServer) MailVerify(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
code := params.ByName("code")
k := mailLinkKey{mailLinkVerifyEmail, code}
userSub, ok := h.mailLinkCache.Get(k)
if !ok {
http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
return
}
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.VerifyUserEmail(req.Context(), userSub)
}) {
return
}
h.mailLinkCache.Delete(k)
http.Error(rw, "Email address has been verified, you may close this tab and return to the login page.", http.StatusOK)
}
func (h *httpServer) MailPassword(rw http.ResponseWriter, _ *http.Request, params httprouter.Params) {
code := params.ByName("code")
k := mailLinkKey{mailLinkResetPassword, code}
_, ok := h.mailLinkCache.Get(k)
if !ok {
http.Error(rw, "Invalid password reset code", http.StatusBadRequest)
return
}
pages.RenderPageTemplate(rw, "reset-password", map[string]any{
"ServiceName": h.conf.ServiceName,
"Code": code,
})
}
func (h *httpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
pw := req.PostFormValue("new_password")
rpw := req.PostFormValue("confirm_password")
code := req.PostFormValue("code")
// reverse passwords are possible
if len(pw) == 0 {
http.Error(rw, "Cannot set an empty password", http.StatusBadRequest)
return
}
// bcrypt only allows up to 72 bytes anyway
if len(pw) > 64 {
http.Error(rw, "Security by extremely long password is a weird flex", http.StatusBadRequest)
return
}
if rpw != pw {
http.Error(rw, "Passwords do not match", http.StatusBadRequest)
return
}
k := mailLinkKey{mailLinkResetPassword, code}
userSub, ok := h.mailLinkCache.Get(k)
if !ok {
http.Error(rw, "Invalid password reset code", http.StatusBadRequest)
return
}
h.mailLinkCache.Delete(k)
// reset password database call
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.ChangePassword(req.Context(), userSub, pw)
}) {
return
}
http.Error(rw, "Reset password successfully, you can login now.", http.StatusOK)
}
func (h *httpServer) MailDelete(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
code := params.ByName("code")
k := mailLinkKey{mailLinkDelete, code}
userSub, ok := h.mailLinkCache.Get(k)
if !ok {
http.Error(rw, "Invalid email delete code", http.StatusBadRequest)
return
}
var userInfo database.User
if h.DbTx(rw, func(tx *database.Queries) (err error) {
userInfo, err = tx.GetUser(req.Context(), userSub)
if err != nil {
return
}
return tx.FlagUserAsDeleted(req.Context(), userSub)
}) {
return
}
h.mailLinkCache.Delete(k)
// parse email for headers
address, err := mail.ParseAddress(userInfo.Email)
if err != nil {
http.Error(rw, "500 Internal Server Error: Failed to parse user email address", http.StatusInternalServerError)
return
}
err = h.conf.Mail.SendEmailTemplate("mail-account-delete", "Account Deletion", userInfo.Name, address, nil)
if err != nil {
http.Error(rw, "Failed to send confirmation email.", http.StatusInternalServerError)
return
}
http.Error(rw, "You will receive an email shortly to verify this action, you may close this tab.", http.StatusOK)
}

View File

@ -12,11 +12,17 @@ import (
"strconv" "strconv"
) )
func SetupManageApps(r *httprouter.Router, hs *httpServer) {
r.GET("/manage/apps", hs.RequireAuthentication(hs.ManageAppsGet))
r.GET("/manage/apps/create", hs.RequireAuthentication(hs.ManageAppsCreateGet))
r.POST("/manage/apps", hs.RequireAuthentication(hs.ManageAppsPost))
}
func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
q := req.URL.Query() q := req.URL.Query()
offset, _ := strconv.Atoi(q.Get("offset")) offset, _ := strconv.Atoi(q.Get("offset"))
var roles string var roles []string
var appList []database.GetAppListRow var appList []database.GetAppListRow
if h.DbTx(rw, func(tx *database.Queries) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
roles, err = tx.GetUserRoles(req.Context(), auth.Subject) roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
@ -24,7 +30,7 @@ func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
return return
} }
appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{ appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{
Owner: auth.Subject, OwnerSubject: auth.Subject,
Column2: HasRole(roles, role.LavenderAdmin), Column2: HasRole(roles, role.LavenderAdmin),
Offset: int64(offset), Offset: int64(offset),
}) })
@ -61,7 +67,7 @@ func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
} }
func (h *httpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { func (h *httpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
var roles string var roles []string
if h.DbTx(rw, func(tx *database.Queries) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
roles, err = tx.GetUserRoles(req.Context(), auth.Subject) roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
return return
@ -96,7 +102,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
active := req.Form.Has("active") active := req.Form.Has("active")
if sso || hasPerms { if sso || hasPerms {
var roles string var roles []string
if h.DbTx(rw, func(tx *database.Queries) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
roles, err = tx.GetUserRoles(req.Context(), auth.Subject) roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
return return
@ -125,7 +131,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
Name: name, Name: name,
Secret: secret, Secret: secret,
Domain: domain, Domain: domain,
Owner: auth.Subject, OwnerSubject: auth.Subject,
Perms: perms, Perms: perms,
Public: public, Public: public,
Sso: sso, Sso: sso,
@ -145,7 +151,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
Sso: sso, Sso: sso,
Active: active, Active: active,
Subject: req.FormValue("subject"), Subject: req.FormValue("subject"),
Owner: auth.Subject, OwnerSubject: auth.Subject,
}) })
}) { }) {
return return
@ -166,7 +172,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{ err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{
Secret: secret, Secret: secret,
Subject: sub, Subject: sub,
Owner: auth.Subject, OwnerSubject: auth.Subject,
}) })
return err return err
}) { }) {

View File

@ -5,16 +5,22 @@ import (
"github.com/1f349/lavender/pages" "github.com/1f349/lavender/pages"
"github.com/1f349/lavender/role" "github.com/1f349/lavender/role"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"golang.org/x/sync/errgroup"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
) )
func SetupManageUsers(r *httprouter.Router, hs *httpServer) {
r.GET("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersGet))
r.POST("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersPost))
}
func (h *httpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { func (h *httpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
q := req.URL.Query() q := req.URL.Query()
offset, _ := strconv.Atoi(q.Get("offset")) offset, _ := strconv.Atoi(q.Get("offset"))
var roles string var roles []string
var userList []database.GetUserListRow var userList []database.GetUserListRow
if h.DbTx(rw, func(tx *database.Queries) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
roles, err = tx.GetUserRoles(req.Context(), auth.Subject) roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
@ -64,7 +70,7 @@ func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
return return
} }
var roles string var roles []string
if h.DbTx(rw, func(tx *database.Queries) (err error) { if h.DbTx(rw, func(tx *database.Queries) (err error) {
roles, err = tx.GetUserRoles(req.Context(), auth.Subject) roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
return return
@ -78,18 +84,38 @@ func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
offset := req.Form.Get("offset") offset := req.Form.Get("offset")
action := req.Form.Get("action") action := req.Form.Get("action")
newRoles := req.Form.Get("roles") newRoles := req.Form["roles"]
active := req.Form.Has("active") active := req.Form.Has("active")
switch action { switch action {
case "edit": case "edit":
if h.DbTx(rw, func(tx *database.Queries) error { if h.DbTx(rw, func(tx *database.Queries) error {
sub := req.Form.Get("subject") sub := req.Form.Get("subject")
return tx.UpdateUser(req.Context(), database.UpdateUserParams{ return tx.UseTx(req.Context(), func(tx *database.Queries) (err error) {
Active: active, err = tx.ChangeUserActive(req.Context(), database.ChangeUserActiveParams{Column1: active, Subject: sub})
Roles: newRoles, if err != nil {
return err
}
err = tx.RemoveUserRoles(req.Context(), sub)
if err != nil {
return err
}
errGrp := new(errgroup.Group)
errGrp.SetLimit(3)
for _, roleName := range newRoles {
errGrp.Go(func() error {
roleId, err := strconv.ParseInt(roleName, 10, 64)
if err != nil {
return err
}
return tx.AddUserRole(req.Context(), database.AddUserRoleParams{
RoleID: roleId,
Subject: sub, Subject: sub,
}) })
})
}
return errGrp.Wait()
})
}) { }) {
return return
} }

View File

@ -1,15 +1,143 @@
package server package server
import ( import (
"encoding/json"
clientStore "github.com/1f349/lavender/client-store"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/logger" "github.com/1f349/lavender/logger"
"github.com/1f349/lavender/pages" "github.com/1f349/lavender/pages"
"github.com/1f349/lavender/scope" "github.com/1f349/lavender/scope"
"github.com/1f349/lavender/utils"
"github.com/1f349/mjwt"
"github.com/go-oauth2/oauth2/v4/generates"
"github.com/go-oauth2/oauth2/v4/manage"
"github.com/go-oauth2/oauth2/v4/server"
"github.com/go-oauth2/oauth2/v4/store"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
) )
func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *database.Queries) {
oauthManager := manage.NewManager()
oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
oauthManager.MustTokenStorage(store.NewMemoryTokenStore())
oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(key, db))
oauthManager.MapClientStorage(clientStore.New(db))
oauthSrv := server.NewDefaultServer(oauthManager)
oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) {
cId, cSecret, err := server.ClientBasicHandler(req)
if cId == "" && cSecret == "" {
cId, cSecret, err = server.ClientFormHandler(req)
}
if err != nil {
return "", "", err
}
return cId, cSecret, nil
})
oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization)
oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (string, error) {
var form url.Values
if req.Method == http.MethodPost {
form = req.PostForm
} else {
form = req.URL.Query()
}
a := form.Get("scope")
if !scope.ScopesExist(a) {
return "", errInvalidScope
}
return a, nil
})
addIdTokenSupport(oauthSrv, db, key)
r.GET("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint))
r.POST("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint))
r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if err := oauthSrv.HandleTokenRequest(rw, req); err != nil {
http.Error(rw, "Failed to handle token request", http.StatusInternalServerError)
}
})
}
func (h *httpServer) userInfoRequest(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization,Content-Type")
rw.Header().Set("Access-Control-Allow-Origin", strings.TrimSuffix(req.Referer(), "/"))
rw.Header().Set("Access-Control-Allow-Methods", "GET")
if req.Method == http.MethodOptions {
return
}
token, err := h.oauthSrv.ValidationBearerToken(req)
if err != nil {
http.Error(rw, "403 Forbidden", http.StatusForbidden)
return
}
userId := token.GetUserID()
sso := h.manager.FindServiceFromLogin(userId)
if sso == nil {
http.Error(rw, "Invalid user", http.StatusBadRequest)
return
}
var user database.User
if h.DbTx(rw, func(tx *database.Queries) (err error) {
user, err = tx.GetUser(req.Context(), userId)
return
}) {
return
}
claims := ParseClaims(token.GetScope())
if !claims["openid"] {
http.Error(rw, "Invalid scope", http.StatusBadRequest)
return
}
m := make(map[string]any)
if claims["name"] {
m["name"] = user.Name
}
if claims["username"] {
m["preferred_username"] = user.Login
m["login"] = user.Login
}
if claims["profile"] {
m["profile"] = user.ProfileUrl
m["picture"] = user.Picture
m["website"] = user.Website
}
if claims["email"] {
m["email"] = user.Email
m["email_verified"] = user.EmailVerified
}
if claims["birthdate"] && user.Birthdate.Valid {
m["birthdate"] = user.Birthdate.Date
}
if claims["age"] && user.Birthdate.Valid {
m["age"] = utils.Age(user.Birthdate.Date.ToTime())
}
if claims["zoneinfo"] {
m["zoneinfo"] = user.Zone
}
if claims["locale"] {
m["locale"] = user.Locale
}
m["sub"] = userId
m["aud"] = token.GetClientID()
m["updated_at"] = time.Now().Unix()
_ = json.NewEncoder(rw).Encode(m)
}
func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
// function is only called with GET or POST method // function is only called with GET or POST method
isPost := req.Method == http.MethodPost isPost := req.Method == http.MethodPost
@ -95,7 +223,7 @@ func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request
"ServiceName": h.conf.ServiceName, "ServiceName": h.conf.ServiceName,
"AppName": appName, "AppName": appName,
"AppDomain": appDomain, "AppDomain": appDomain,
"DisplayName": auth.DisplayName, "DisplayName": auth.UserInfo.GetStringOrEmpty("name"),
"WantsList": scope.FancyScopeList(scopeList), "WantsList": scope.FancyScopeList(scopeList),
"ResponseType": form.Get("response_type"), "ResponseType": form.Get("response_type"),
"ResponseMode": form.Get("response_mode"), "ResponseMode": form.Get("response_mode"),

196
server/otp.go Normal file
View File

@ -0,0 +1,196 @@
package server
import (
"bytes"
"context"
"encoding/base64"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/pages"
"github.com/julienschmidt/httprouter"
"github.com/skip2/go-qrcode"
"github.com/xlzd/gotp"
"html/template"
"image/png"
"net/http"
"time"
)
func (h *httpServer) loginOtpGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
if !auth.NeedOtp {
h.SafeRedirect(rw, req)
return
}
pages.RenderPageTemplate(rw, "login-otp", map[string]any{
"ServiceName": h.conf.ServiceName,
"Redirect": req.URL.Query().Get("redirect"),
})
}
func (h *httpServer) loginOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
if !auth.NeedOtp {
http.Redirect(rw, req, "/", http.StatusFound)
return
}
otpInput := req.FormValue("code")
if h.fetchAndValidateOtp(rw, auth.Subject, otpInput) {
return
}
auth.NeedOtp = false
h.setLoginDataCookie2(rw, auth)
h.SafeRedirect(rw, req)
}
func (h *httpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code string) bool {
var hasOtp bool
var otpRow database.GetOtpRow
var secret string
var digits int64
if h.DbTx(rw, func(tx *database.Queries) (err error) {
hasOtp, err = tx.HasOtp(context.Background(), sub)
if err != nil {
return
}
if hasOtp {
otpRow, err = tx.GetOtp(context.Background(), sub)
secret = otpRow.OtpSecret
digits = otpRow.OtpDigits
}
return
}) {
return true
}
if hasOtp {
totp := gotp.NewTOTP(secret, int(digits), 30, nil)
if !verifyTotp(totp, code) {
http.Error(rw, "400 Bad Request: Invalid OTP code", http.StatusBadRequest)
return true
}
}
return false
}
func (h *httpServer) editOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
if req.Method == http.MethodPost && req.FormValue("remove") == "1" {
if !req.Form.Has("code") {
// render page
pages.RenderPageTemplate(rw, "remove-otp", map[string]any{
"ServiceName": h.conf.ServiceName,
})
return
}
otpInput := req.Form.Get("code")
if h.fetchAndValidateOtp(rw, auth.Subject, otpInput) {
return
}
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.DeleteOtp(req.Context(), auth.Subject)
}) {
return
}
http.Redirect(rw, req, "/", http.StatusFound)
return
}
var digits int
switch req.FormValue("digits") {
case "6":
digits = 6
case "7":
digits = 7
case "8":
digits = 8
default:
http.Error(rw, "400 Bad Request: Invalid number of digits for OTP code", http.StatusBadRequest)
return
}
secret := req.FormValue("secret")
if !gotp.IsSecretValid(secret) {
http.Error(rw, "400 Bad Request: Invalid secret", http.StatusBadRequest)
return
}
if secret == "" {
// get user email
var email string
if h.DbTx(rw, func(tx *database.Queries) error {
var err error
email, err = tx.GetUserEmail(req.Context(), auth.Subject)
return err
}) {
return
}
secret = gotp.RandomSecret(64)
if secret == "" {
http.Error(rw, "500 Internal Server Error: failed to generate OTP secret", http.StatusInternalServerError)
return
}
totp := gotp.NewTOTP(secret, digits, 30, nil)
otpUri := totp.ProvisioningUri(email, h.conf.OtpIssuer)
code, err := qrcode.New(otpUri, qrcode.Medium)
if err != nil {
http.Error(rw, "500 Internal Server Error: failed to generate QR code", http.StatusInternalServerError)
return
}
qrImg := code.Image(60 * 4)
qrBounds := qrImg.Bounds()
qrWidth := qrBounds.Dx()
qrBuf := new(bytes.Buffer)
if png.Encode(qrBuf, qrImg) != nil {
http.Error(rw, "500 Internal Server Error: failed to generate PNG image of QR code", http.StatusInternalServerError)
return
}
// render page
pages.RenderPageTemplate(rw, "edit-otp", map[string]any{
"ServiceName": h.conf.ServiceName,
"OtpQr": template.URL("data:qrImg/png;base64," + base64.StdEncoding.EncodeToString(qrBuf.Bytes())),
"QrWidth": qrWidth,
"OtpUrl": otpUri,
"OtpSecret": secret,
"OtpDigits": digits,
})
return
}
totp := gotp.NewTOTP(secret, digits, 30, nil)
if !verifyTotp(totp, req.FormValue("code")) {
http.Error(rw, "400 Bad Request: invalid OTP code", http.StatusBadRequest)
return
}
if h.DbTx(rw, func(tx *database.Queries) error {
return tx.SetOtp(req.Context(), database.SetOtpParams{
Subject: auth.Subject,
OtpSecret: secret,
OtpDigits: int64(digits),
})
}) {
return
}
http.Redirect(rw, req, "/", http.StatusFound)
}
func verifyTotp(totp *gotp.TOTP, code string) bool {
t := time.Now()
if totp.VerifyTime(code, t) {
return true
}
if totp.VerifyTime(code, t.Add(-30*time.Second)) {
return true
}
return totp.VerifyTime(code, t.Add(30*time.Second))
}

View File

@ -7,6 +7,5 @@ import (
func TestHasRole(t *testing.T) { func TestHasRole(t *testing.T) {
assert.True(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin")) assert.True(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin"))
assert.False(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin"))
assert.False(t, HasRole([]string{"lavender:", "test:something-else"}, "lavender:admin")) assert.False(t, HasRole([]string{"lavender:", "test:something-else"}, "lavender:admin"))
} }

View File

@ -3,17 +3,14 @@ package server
import ( import (
"errors" "errors"
"github.com/1f349/cache" "github.com/1f349/cache"
clientStore "github.com/1f349/lavender/client-store"
"github.com/1f349/lavender/conf" "github.com/1f349/lavender/conf"
"github.com/1f349/lavender/database" "github.com/1f349/lavender/database"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/logger"
"github.com/1f349/lavender/pages" "github.com/1f349/lavender/pages"
scope2 "github.com/1f349/lavender/scope"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/go-oauth2/oauth2/v4/generates"
"github.com/go-oauth2/oauth2/v4/manage" "github.com/go-oauth2/oauth2/v4/manage"
"github.com/go-oauth2/oauth2/v4/server" "github.com/go-oauth2/oauth2/v4/server"
"github.com/go-oauth2/oauth2/v4/store"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
"net/url" "net/url"
@ -76,44 +73,15 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, s
mailLinkCache: cache.New[mailLinkKey, string](), mailLinkCache: cache.New[mailLinkKey, string](),
} }
oauthManager := manage.NewManager() var err error
oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) hs.manager, err = issuer.NewManager(config.SsoServices)
oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
oauthManager.MustTokenStorage(store.NewMemoryTokenStore())
oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(signingKey, db))
oauthManager.MapClientStorage(clientStore.New(db))
oauthSrv := server.NewDefaultServer(oauthManager)
oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) {
cId, cSecret, err := server.ClientBasicHandler(req)
if cId == "" && cSecret == "" {
cId, cSecret, err = server.ClientFormHandler(req)
}
if err != nil { if err != nil {
return "", "", err logger.Logger.Fatal("Failed to load SSO services", "err", err)
} }
return cId, cSecret, nil
})
oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization)
oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (scope string, err error) {
var form url.Values
if req.Method == http.MethodPost {
form = req.PostForm
} else {
form = req.URL.Query()
}
a := form.Get("scope")
if !scope2.ScopesExist(a) {
return "", errInvalidScope
}
return a, nil
})
addIdTokenSupport(oauthSrv, db, signingKey)
ssoManager := issuer.NewManager(config.SsoServices)
SetupOpenId(r, config.BaseUrl, signingKey) SetupOpenId(r, config.BaseUrl, signingKey)
r.POST("/logout", hs.RequireAuthentication(fu)) r.GET("/", hs.OptionalAuthentication(false, hs.Home))
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))
// theme styles // theme styles
r.GET("/assets/*filepath", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { r.GET("/assets/*filepath", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
@ -126,8 +94,16 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, s
http.ServeContent(rw, req, path.Base(name), contentCache, out) http.ServeContent(rw, req, path.Base(name), contentCache, out)
}) })
SetupManageApps(r) // login steps
SetupManageUsers(r) r.GET("/login", hs.OptionalAuthentication(false, hs.loginGet))
r.POST("/login", hs.OptionalAuthentication(false, hs.loginPost))
r.GET("/login/otp", hs.OptionalAuthentication(true, hs.loginOtpGet))
r.POST("/login/otp", hs.OptionalAuthentication(true, hs.loginOtpPost))
r.GET("/callback", hs.OptionalAuthentication(false, hs.loginCallback))
SetupManageApps(r, hs)
SetupManageUsers(r, hs)
SetupOAuth2(r, hs, signingKey, db)
} }
func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) { func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) {

View File

@ -21,3 +21,11 @@ sql:
go_type: "github.com/1f349/lavender/database/types.UserZone" go_type: "github.com/1f349/lavender/database/types.UserZone"
- column: "users.locale" - column: "users.locale"
go_type: "github.com/1f349/lavender/database/types.UserLocale" go_type: "github.com/1f349/lavender/database/types.UserLocale"
- column: "users.auth_type"
go_type: "github.com/1f349/lavender/database/types.AuthType"
- column: "users.access_token"
go_type: "database/sql.NullString"
- column: "users.refresh_token"
go_type: "database/sql.NullString"
- column: "users.token_expiry"
go_type: "database/sql.NullTime"

28
utils/age.go Normal file
View File

@ -0,0 +1,28 @@
package utils
import (
"time"
)
var ageTimeNow = time.Now
func Age(t time.Time) int {
n := ageTimeNow()
// the birthday is in the future so the age is 0
if n.Before(t) {
return 0
}
// the year difference
dy := n.Year() - t.Year()
// the birthday in the current year
tCurrent := t.AddDate(dy, 0, 0)
// minus 1 if the birthday has not yet occurred in the current year
if tCurrent.Before(n) {
dy -= 1
}
return dy
}

30
utils/age_test.go Normal file
View File

@ -0,0 +1,30 @@
package utils
import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestAge(t *testing.T) {
lGmt := time.FixedZone("GMT", 0)
lBst := time.FixedZone("BST", 60*60)
tPast := time.Date(1939, time.January, 5, 0, 0, 0, 0, lGmt)
tPastDst := time.Date(2001, time.January, 5, 1, 0, 0, 0, lBst)
tCur := time.Date(2005, time.January, 5, 0, 30, 0, 0, lGmt)
tCurDst := time.Date(2005, time.January, 5, 0, 30, 0, 0, lBst)
tFut := time.Date(2008, time.January, 5, 0, 0, 0, 0, time.UTC)
ageTimeNow = func() time.Time { return tCur }
assert.Equal(t, 65, Age(tPast))
assert.Equal(t, 3, Age(tPastDst))
assert.Equal(t, 0, Age(tFut))
ageTimeNow = func() time.Time { return tCurDst }
assert.Equal(t, 66, Age(tPast))
assert.Equal(t, 4, Age(tPastDst))
fmt.Println(tPastDst.AddDate(4, 0, 0).UTC(), tCur.UTC())
assert.Equal(t, 0, Age(tFut))
}