mirror of
https://github.com/1f349/lavender.git
synced 2024-12-22 15:44:07 +00:00
Fix a bunch more compile breaking issues
This commit is contained in:
parent
7064afd55e
commit
d25f9ae2ca
@ -1,5 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
|
import "github.com/hardfinhq/go-date"
|
||||||
|
|
||||||
type UserInfoFields map[string]any
|
type UserInfoFields map[string]any
|
||||||
|
|
||||||
func (u UserInfoFields) GetString(key string) (string, bool) {
|
func (u UserInfoFields) GetString(key string) (string, bool) {
|
||||||
@ -20,7 +22,24 @@ func (u UserInfoFields) GetStringOrEmpty(key string) string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u UserInfoFields) GetStringFromKeysOrEmpty(keys ...string) string {
|
||||||
|
for _, key := range keys {
|
||||||
|
s, _ := u[key].(string)
|
||||||
|
if s == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (u UserInfoFields) GetBoolean(key string) (bool, bool) {
|
func (u UserInfoFields) GetBoolean(key string) (bool, bool) {
|
||||||
b, ok := u[key].(bool)
|
b, ok := u[key].(bool)
|
||||||
return b, ok
|
return b, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u UserInfoFields) GetNullDate(key string) date.NullDate {
|
||||||
|
s, _ := u[key].(string)
|
||||||
|
fromStr, err := date.FromString(s)
|
||||||
|
return date.NullDate{Date: fromStr, Valid: err == nil}
|
||||||
|
}
|
||||||
|
@ -3,10 +3,13 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"github.com/1f349/lavender"
|
"github.com/1f349/lavender"
|
||||||
"github.com/1f349/lavender/conf"
|
"github.com/1f349/lavender/conf"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/logger"
|
"github.com/1f349/lavender/logger"
|
||||||
"github.com/1f349/lavender/pages"
|
"github.com/1f349/lavender/pages"
|
||||||
|
"github.com/1f349/lavender/role"
|
||||||
"github.com/1f349/lavender/server"
|
"github.com/1f349/lavender/server"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
@ -114,6 +117,10 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
|||||||
logger.Logger.Fatal("Failed to open database", "err", err)
|
logger.Logger.Fatal("Failed to open database", "err", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := checkDbHasUser(db); err != nil {
|
||||||
|
logger.Logger.Fatal("Failed to add initial user", "err", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := pages.LoadPages(wd); err != nil {
|
if err := pages.LoadPages(wd); err != nil {
|
||||||
logger.Logger.Fatal("Failed to load page templates:", err)
|
logger.Logger.Fatal("Failed to load page templates:", err)
|
||||||
}
|
}
|
||||||
@ -168,3 +175,45 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
|||||||
|
|
||||||
return subcommands.ExitSuccess
|
return subcommands.ExitSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkDbHasUser(db *database.Queries) error {
|
||||||
|
value, err := db.HasUser(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !value {
|
||||||
|
logger.Logger.Warn("No users are available, setting up initial admin user")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = db.UseTx(ctx, func(tx *database.Queries) error {
|
||||||
|
adminUuid, err := db.AddLocalUser(context.Background(), database.AddLocalUserParams{
|
||||||
|
Password: "admin",
|
||||||
|
Email: "admin@localhost",
|
||||||
|
EmailVerified: false,
|
||||||
|
Name: "Admin",
|
||||||
|
Username: "admin",
|
||||||
|
ChangePassword: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add user: %w", err)
|
||||||
|
}
|
||||||
|
roleId, err := db.AddRole(context.Background(), role.LavenderAdmin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add role: %w", err)
|
||||||
|
}
|
||||||
|
err = db.AddUserRole(context.Background(), database.AddUserRoleParams{
|
||||||
|
RoleID: roleId,
|
||||||
|
Subject: adminUuid,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add user role: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -12,6 +12,7 @@ type Conf struct {
|
|||||||
Issuer string `yaml:"issuer"`
|
Issuer string `yaml:"issuer"`
|
||||||
Kid string `yaml:"kid"`
|
Kid string `yaml:"kid"`
|
||||||
Namespace string `yaml:"namespace"`
|
Namespace string `yaml:"namespace"`
|
||||||
|
OtpIssuer string `yaml:"otpIssuer"`
|
||||||
Mail mail.Mail `yaml:"mail"`
|
Mail mail.Mail `yaml:"mail"`
|
||||||
SsoServices []issuer.SsoConfig `yaml:"ssoServices"`
|
SsoServices map[string]issuer.SsoConfig `yaml:"ssoServices"`
|
||||||
}
|
}
|
||||||
|
@ -20,13 +20,13 @@ SELECT subject,
|
|||||||
active
|
active
|
||||||
FROM client_store
|
FROM client_store
|
||||||
WHERE owner_subject = ?
|
WHERE owner_subject = ?
|
||||||
OR ? = 1
|
OR CAST(? AS BOOLEAN) = 1
|
||||||
LIMIT 25 OFFSET ?
|
LIMIT 25 OFFSET ?
|
||||||
`
|
`
|
||||||
|
|
||||||
type GetAppListParams struct {
|
type GetAppListParams struct {
|
||||||
OwnerSubject string `json:"owner_subject"`
|
OwnerSubject string `json:"owner_subject"`
|
||||||
Column2 interface{} `json:"column_2"`
|
Column2 bool `json:"column_2"`
|
||||||
Offset int64 `json:"offset"`
|
Offset int64 `json:"offset"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,10 +7,30 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const addUserRole = `-- name: AddUserRole :exec
|
||||||
|
INSERT INTO users_roles(role_id, user_id)
|
||||||
|
SELECT ?, users.id
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
type AddUserRoleParams struct {
|
||||||
|
RoleID int64 `json:"role_id"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) AddUserRole(ctx context.Context, arg AddUserRoleParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, addUserRole, arg.RoleID, arg.Subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const changeUserActive = `-- name: ChangeUserActive :exec
|
const changeUserActive = `-- name: ChangeUserActive :exec
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET active = cast(? as boolean)
|
SET active = cast(? as boolean)
|
||||||
@ -34,11 +54,9 @@ SELECT users.subject,
|
|||||||
website,
|
website,
|
||||||
email,
|
email,
|
||||||
email_verified,
|
email_verified,
|
||||||
users.updated_at as user_updated_at,
|
updated_at,
|
||||||
p.updated_at as profile_updated_at,
|
|
||||||
active
|
active
|
||||||
FROM users
|
FROM users
|
||||||
INNER JOIN main.profiles p on users.subject = p.subject
|
|
||||||
LIMIT 50 OFFSET ?
|
LIMIT 50 OFFSET ?
|
||||||
`
|
`
|
||||||
|
|
||||||
@ -49,11 +67,11 @@ type GetUserListRow struct {
|
|||||||
Website string `json:"website"`
|
Website string `json:"website"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
UserUpdatedAt time.Time `json:"user_updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
ProfileUpdatedAt time.Time `json:"profile_updated_at"`
|
|
||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// INNER JOIN main.profiles p on users.subject = p.subject
|
||||||
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
|
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
|
||||||
rows, err := q.db.QueryContext(ctx, getUserList, offset)
|
rows, err := q.db.QueryContext(ctx, getUserList, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -70,8 +88,7 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR
|
|||||||
&i.Website,
|
&i.Website,
|
||||||
&i.Email,
|
&i.Email,
|
||||||
&i.EmailVerified,
|
&i.EmailVerified,
|
||||||
&i.UserUpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.ProfileUpdatedAt,
|
|
||||||
&i.Active,
|
&i.Active,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -87,6 +104,25 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getUserToken = `-- name: GetUserToken :one
|
||||||
|
SELECT access_token, refresh_token, token_expiry
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetUserTokenRow struct {
|
||||||
|
AccessToken sql.NullString `json:"access_token"`
|
||||||
|
RefreshToken sql.NullString `json:"refresh_token"`
|
||||||
|
TokenExpiry sql.NullTime `json:"token_expiry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetUserToken(ctx context.Context, subject string) (GetUserTokenRow, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getUserToken, subject)
|
||||||
|
var i GetUserTokenRow
|
||||||
|
err := row.Scan(&i.AccessToken, &i.RefreshToken, &i.TokenExpiry)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const getUsersRoles = `-- name: GetUsersRoles :many
|
const getUsersRoles = `-- name: GetUsersRoles :many
|
||||||
SELECT r.role, u.id
|
SELECT r.role, u.id
|
||||||
FROM users_roles
|
FROM users_roles
|
||||||
@ -133,6 +169,105 @@ func (q *Queries) GetUsersRoles(ctx context.Context, userIds []int64) ([]GetUser
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modifyUserAuth = `-- name: ModifyUserAuth :exec
|
||||||
|
UPDATE users
|
||||||
|
SET auth_type = ?,
|
||||||
|
auth_namespace=?,
|
||||||
|
auth_user = ?
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
type ModifyUserAuthParams struct {
|
||||||
|
AuthType types.AuthType `json:"auth_type"`
|
||||||
|
AuthNamespace string `json:"auth_namespace"`
|
||||||
|
AuthUser string `json:"auth_user"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) ModifyUserAuth(ctx context.Context, arg ModifyUserAuthParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, modifyUserAuth,
|
||||||
|
arg.AuthType,
|
||||||
|
arg.AuthNamespace,
|
||||||
|
arg.AuthUser,
|
||||||
|
arg.Subject,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const modifyUserEmail = `-- name: ModifyUserEmail :exec
|
||||||
|
UPDATE users
|
||||||
|
SET email = ?,
|
||||||
|
email_verified=?
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
type ModifyUserEmailParams struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) ModifyUserEmail(ctx context.Context, arg ModifyUserEmailParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, modifyUserEmail, arg.Email, arg.EmailVerified, arg.Subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const modifyUserRemoteLogin = `-- name: ModifyUserRemoteLogin :exec
|
||||||
|
UPDATE users
|
||||||
|
SET login = ?,
|
||||||
|
profile_url = ?
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
type ModifyUserRemoteLoginParams struct {
|
||||||
|
Login string `json:"login"`
|
||||||
|
ProfileUrl string `json:"profile_url"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) ModifyUserRemoteLogin(ctx context.Context, arg ModifyUserRemoteLoginParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, modifyUserRemoteLogin, arg.Login, arg.ProfileUrl, arg.Subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const removeUserRoles = `-- name: RemoveUserRoles :exec
|
||||||
|
DELETE
|
||||||
|
FROM users_roles
|
||||||
|
WHERE user_id IN (SELECT id
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RemoveUserRoles(ctx context.Context, subject string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, removeUserRoles, subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateUserToken = `-- name: UpdateUserToken :exec
|
||||||
|
UPDATE users
|
||||||
|
SET access_token = ?,
|
||||||
|
refresh_token=?,
|
||||||
|
token_expiry = ?
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateUserTokenParams struct {
|
||||||
|
AccessToken sql.NullString `json:"access_token"`
|
||||||
|
RefreshToken sql.NullString `json:"refresh_token"`
|
||||||
|
TokenExpiry sql.NullTime `json:"token_expiry"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateUserToken(ctx context.Context, arg UpdateUserTokenParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, updateUserToken,
|
||||||
|
arg.AccessToken,
|
||||||
|
arg.RefreshToken,
|
||||||
|
arg.TokenExpiry,
|
||||||
|
arg.Subject,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const userEmailExists = `-- name: UserEmailExists :one
|
const userEmailExists = `-- name: UserEmailExists :one
|
||||||
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists
|
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists
|
||||||
`
|
`
|
||||||
@ -143,3 +278,14 @@ func (q *Queries) UserEmailExists(ctx context.Context, email string) (bool, erro
|
|||||||
err := row.Scan(&email_exists)
|
err := row.Scan(&email_exists)
|
||||||
return email_exists, err
|
return email_exists, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const verifyUserEmail = `-- name: VerifyUserEmail :exec
|
||||||
|
UPDATE users
|
||||||
|
SET email_verified=1
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) VerifyUserEmail(ctx context.Context, subject string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, verifyUserEmail, subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
@ -21,9 +21,21 @@ CREATE TABLE users
|
|||||||
zone TEXT NOT NULL DEFAULT 'UTC',
|
zone TEXT NOT NULL DEFAULT 'UTC',
|
||||||
locale TEXT NOT NULL DEFAULT 'en-US',
|
locale TEXT NOT NULL DEFAULT 'en-US',
|
||||||
|
|
||||||
|
login TEXT NOT NULL DEFAULT '',
|
||||||
|
profile_url TEXT NOT NULL DEFAULT '',
|
||||||
|
|
||||||
auth_type INTEGER NOT NULL,
|
auth_type INTEGER NOT NULL,
|
||||||
auth_namespace TEXT NOT NULL,
|
auth_namespace TEXT NOT NULL,
|
||||||
auth_user TEXT NOT NULL
|
auth_user TEXT NOT NULL,
|
||||||
|
|
||||||
|
access_token TEXT NULL DEFAULT NULL,
|
||||||
|
refresh_token TEXT NULL DEFAULT NULL,
|
||||||
|
token_expiry DATETIME NULL DEFAULT NULL,
|
||||||
|
|
||||||
|
otp_secret TEXT NOT NULL DEFAULT '',
|
||||||
|
otp_digits INTEGER NOT NULL DEFAULT 0,
|
||||||
|
|
||||||
|
to_delete BOOLEAN NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX users_subject ON users (subject);
|
CREATE INDEX users_subject ON users (subject);
|
||||||
@ -39,21 +51,12 @@ CREATE TABLE users_roles
|
|||||||
role_id INTEGER NOT NULL,
|
role_id INTEGER NOT NULL,
|
||||||
user_id INTEGER NOT NULL,
|
user_id INTEGER NOT NULL,
|
||||||
|
|
||||||
FOREIGN KEY (role_id) REFERENCES roles (id),
|
FOREIGN KEY (role_id) REFERENCES roles (id) ON DELETE RESTRICT,
|
||||||
FOREIGN KEY (user_id) REFERENCES users (id),
|
FOREIGN KEY (user_id) REFERENCES users (id),
|
||||||
|
|
||||||
CONSTRAINT user_role UNIQUE (role_id, user_id)
|
CONSTRAINT user_role UNIQUE (role_id, user_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE otp
|
|
||||||
(
|
|
||||||
subject INTEGER NOT NULL UNIQUE PRIMARY KEY,
|
|
||||||
secret TEXT NOT NULL,
|
|
||||||
digits INTEGER NOT NULL,
|
|
||||||
|
|
||||||
FOREIGN KEY (subject) REFERENCES users (subject)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE client_store
|
CREATE TABLE client_store
|
||||||
(
|
(
|
||||||
subject TEXT NOT NULL UNIQUE PRIMARY KEY,
|
subject TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
@ -5,9 +5,12 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/password"
|
"github.com/1f349/lavender/password"
|
||||||
|
"github.com/hardfinhq/go-date"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClientStore struct {
|
type ClientStore struct {
|
||||||
@ -22,24 +25,6 @@ type ClientStore struct {
|
|||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Otp struct {
|
|
||||||
Subject int64 `json:"subject"`
|
|
||||||
Secret string `json:"secret"`
|
|
||||||
Digits int64 `json:"digits"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Profile struct {
|
|
||||||
Subject string `json:"subject"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Picture string `json:"picture"`
|
|
||||||
Website string `json:"website"`
|
|
||||||
Pronouns string `json:"pronouns"`
|
|
||||||
Birthdate interface{} `json:"birthdate"`
|
|
||||||
Zone string `json:"zone"`
|
|
||||||
Locale string `json:"locale"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Role struct {
|
type Role struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
@ -49,11 +34,30 @@ type User struct {
|
|||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Password password.HashString `json:"password"`
|
Password password.HashString `json:"password"`
|
||||||
|
ChangePassword bool `json:"change_password"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Registered time.Time `json:"registered"`
|
Registered time.Time `json:"registered"`
|
||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Picture string `json:"picture"`
|
||||||
|
Website string `json:"website"`
|
||||||
|
Pronouns types.UserPronoun `json:"pronouns"`
|
||||||
|
Birthdate date.NullDate `json:"birthdate"`
|
||||||
|
Zone string `json:"zone"`
|
||||||
|
Locale types.UserLocale `json:"locale"`
|
||||||
|
Login string `json:"login"`
|
||||||
|
ProfileUrl string `json:"profile_url"`
|
||||||
|
AuthType types.AuthType `json:"auth_type"`
|
||||||
|
AuthNamespace string `json:"auth_namespace"`
|
||||||
|
AuthUser string `json:"auth_user"`
|
||||||
|
AccessToken sql.NullString `json:"access_token"`
|
||||||
|
RefreshToken sql.NullString `json:"refresh_token"`
|
||||||
|
TokenExpiry sql.NullTime `json:"token_expiry"`
|
||||||
|
OtpSecret string `json:"otp_secret"`
|
||||||
|
OtpDigits int64 `json:"otp_digits"`
|
||||||
|
ToDelete bool `json:"to_delete"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsersRole struct {
|
type UsersRole struct {
|
||||||
|
@ -10,31 +10,32 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const deleteOtp = `-- name: DeleteOtp :exec
|
const deleteOtp = `-- name: DeleteOtp :exec
|
||||||
DELETE
|
UPDATE users
|
||||||
FROM otp
|
SET otp_secret='',
|
||||||
WHERE otp.subject = ?
|
otp_digits=0
|
||||||
|
WHERE subject = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOtp(ctx context.Context, subject int64) error {
|
func (q *Queries) DeleteOtp(ctx context.Context, subject string) error {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOtp, subject)
|
_, err := q.db.ExecContext(ctx, deleteOtp, subject)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOtp = `-- name: GetOtp :one
|
const getOtp = `-- name: GetOtp :one
|
||||||
SELECT secret, digits
|
SELECT otp_secret, otp_digits
|
||||||
FROM otp
|
FROM users
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
type GetOtpRow struct {
|
type GetOtpRow struct {
|
||||||
Secret string `json:"secret"`
|
OtpSecret string `json:"otp_secret"`
|
||||||
Digits int64 `json:"digits"`
|
OtpDigits int64 `json:"otp_digits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) GetOtp(ctx context.Context, subject int64) (GetOtpRow, error) {
|
func (q *Queries) GetOtp(ctx context.Context, subject string) (GetOtpRow, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOtp, subject)
|
row := q.db.QueryRowContext(ctx, getOtp, subject)
|
||||||
var i GetOtpRow
|
var i GetOtpRow
|
||||||
err := row.Scan(&i.Secret, &i.Digits)
|
err := row.Scan(&i.OtpSecret, &i.OtpDigits)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,10 +53,13 @@ func (q *Queries) GetUserEmail(ctx context.Context, subject string) (string, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
const hasOtp = `-- name: HasOtp :one
|
const hasOtp = `-- name: HasOtp :one
|
||||||
SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp
|
SELECT CAST(1 AS BOOLEAN) AS hasOtp
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?
|
||||||
|
AND otp_secret != ''
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) {
|
func (q *Queries) HasOtp(ctx context.Context, subject string) (bool, error) {
|
||||||
row := q.db.QueryRowContext(ctx, hasOtp, subject)
|
row := q.db.QueryRowContext(ctx, hasOtp, subject)
|
||||||
var hasotp bool
|
var hasotp bool
|
||||||
err := row.Scan(&hasotp)
|
err := row.Scan(&hasotp)
|
||||||
@ -63,19 +67,19 @@ func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const setOtp = `-- name: SetOtp :exec
|
const setOtp = `-- name: SetOtp :exec
|
||||||
INSERT OR
|
UPDATE users
|
||||||
REPLACE
|
SET otp_secret = ?,
|
||||||
INTO otp (subject, secret, digits)
|
otp_digits=?
|
||||||
VALUES (?, ?, ?)
|
WHERE subject = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
type SetOtpParams struct {
|
type SetOtpParams struct {
|
||||||
Subject int64 `json:"subject"`
|
OtpSecret string `json:"otp_secret"`
|
||||||
Secret string `json:"secret"`
|
OtpDigits int64 `json:"otp_digits"`
|
||||||
Digits int64 `json:"digits"`
|
Subject string `json:"subject"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) SetOtp(ctx context.Context, arg SetOtpParams) error {
|
func (q *Queries) SetOtp(ctx context.Context, arg SetOtpParams) error {
|
||||||
_, err := q.db.ExecContext(ctx, setOtp, arg.Subject, arg.Secret, arg.Digits)
|
_, err := q.db.ExecContext(ctx, setOtp, arg.OtpSecret, arg.OtpDigits, arg.Subject)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -2,22 +2,22 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/password"
|
"github.com/1f349/lavender/password"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AddUserParams struct {
|
type AddLocalUserParams struct {
|
||||||
Name string `json:"name"`
|
|
||||||
Subject string `json:"subject"`
|
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
Name string `json:"name"`
|
||||||
Active bool `json:"active"`
|
Username string `json:"username"`
|
||||||
|
ChangePassword bool `json:"change_password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) {
|
func (q *Queries) AddLocalUser(ctx context.Context, arg AddLocalUserParams) (string, error) {
|
||||||
pwHash, err := password.HashPassword(arg.Password)
|
pwHash, err := password.HashPassword(arg.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -31,6 +31,40 @@ func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error
|
|||||||
UpdatedAt: n,
|
UpdatedAt: n,
|
||||||
Registered: n,
|
Registered: n,
|
||||||
Active: true,
|
Active: true,
|
||||||
|
Name: arg.Name,
|
||||||
|
Login: arg.Username,
|
||||||
|
ChangePassword: arg.ChangePassword,
|
||||||
|
AuthType: types.AuthTypeLocal,
|
||||||
|
AuthNamespace: "",
|
||||||
|
AuthUser: arg.Username,
|
||||||
|
}
|
||||||
|
return a.Subject, q.addUser(ctx, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddOAuthUserParams struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
AuthNamespace string `json:"auth_namespace"`
|
||||||
|
AuthUser string `json:"auth_user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) AddOAuthUser(ctx context.Context, arg AddOAuthUserParams) (string, error) {
|
||||||
|
n := time.Now()
|
||||||
|
a := addUserParams{
|
||||||
|
Subject: uuid.NewString(),
|
||||||
|
Email: arg.Email,
|
||||||
|
EmailVerified: arg.EmailVerified,
|
||||||
|
UpdatedAt: n,
|
||||||
|
Registered: n,
|
||||||
|
Active: true,
|
||||||
|
Name: arg.Name,
|
||||||
|
Login: arg.Username,
|
||||||
|
ChangePassword: false,
|
||||||
|
AuthType: types.AuthTypeOauth2,
|
||||||
|
AuthNamespace: arg.AuthNamespace,
|
||||||
|
AuthUser: arg.AuthUser,
|
||||||
}
|
}
|
||||||
return a.Subject, q.addUser(ctx, a)
|
return a.Subject, q.addUser(ctx, a)
|
||||||
}
|
}
|
||||||
|
@ -8,17 +8,38 @@ package database
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
|
"github.com/hardfinhq/go-date"
|
||||||
)
|
)
|
||||||
|
|
||||||
const getProfile = `-- name: GetProfile :one
|
const getProfile = `-- name: GetProfile :one
|
||||||
SELECT profiles.subject, profiles.name, profiles.picture, profiles.website, profiles.pronouns, profiles.birthdate, profiles.zone, profiles.locale, profiles.updated_at
|
SELECT subject,
|
||||||
FROM profiles
|
name,
|
||||||
|
picture,
|
||||||
|
website,
|
||||||
|
pronouns,
|
||||||
|
birthdate,
|
||||||
|
zone,
|
||||||
|
locale
|
||||||
|
FROM users
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, error) {
|
type GetProfileRow struct {
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Picture string `json:"picture"`
|
||||||
|
Website string `json:"website"`
|
||||||
|
Pronouns types.UserPronoun `json:"pronouns"`
|
||||||
|
Birthdate date.NullDate `json:"birthdate"`
|
||||||
|
Zone string `json:"zone"`
|
||||||
|
Locale types.UserLocale `json:"locale"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetProfile(ctx context.Context, subject string) (GetProfileRow, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getProfile, subject)
|
row := q.db.QueryRowContext(ctx, getProfile, subject)
|
||||||
var i Profile
|
var i GetProfileRow
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Subject,
|
&i.Subject,
|
||||||
&i.Name,
|
&i.Name,
|
||||||
@ -28,13 +49,12 @@ func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, erro
|
|||||||
&i.Birthdate,
|
&i.Birthdate,
|
||||||
&i.Zone,
|
&i.Zone,
|
||||||
&i.Locale,
|
&i.Locale,
|
||||||
&i.UpdatedAt,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const modifyProfile = `-- name: ModifyProfile :exec
|
const modifyProfile = `-- name: ModifyProfile :exec
|
||||||
UPDATE profiles
|
UPDATE users
|
||||||
SET name = ?,
|
SET name = ?,
|
||||||
picture = ?,
|
picture = ?,
|
||||||
website = ?,
|
website = ?,
|
||||||
@ -50,10 +70,10 @@ type ModifyProfileParams struct {
|
|||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Picture string `json:"picture"`
|
Picture string `json:"picture"`
|
||||||
Website string `json:"website"`
|
Website string `json:"website"`
|
||||||
Pronouns string `json:"pronouns"`
|
Pronouns types.UserPronoun `json:"pronouns"`
|
||||||
Birthdate interface{} `json:"birthdate"`
|
Birthdate date.NullDate `json:"birthdate"`
|
||||||
Zone string `json:"zone"`
|
Zone string `json:"zone"`
|
||||||
Locale string `json:"locale"`
|
Locale types.UserLocale `json:"locale"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ SELECT subject,
|
|||||||
active
|
active
|
||||||
FROM client_store
|
FROM client_store
|
||||||
WHERE owner_subject = ?
|
WHERE owner_subject = ?
|
||||||
OR ? = 1
|
OR CAST(? AS BOOLEAN) = 1
|
||||||
LIMIT 25 OFFSET ?;
|
LIMIT 25 OFFSET ?;
|
||||||
|
|
||||||
-- name: InsertClientApp :exec
|
-- name: InsertClientApp :exec
|
||||||
|
@ -5,11 +5,10 @@ SELECT users.subject,
|
|||||||
website,
|
website,
|
||||||
email,
|
email,
|
||||||
email_verified,
|
email_verified,
|
||||||
users.updated_at as user_updated_at,
|
updated_at,
|
||||||
p.updated_at as profile_updated_at,
|
|
||||||
active
|
active
|
||||||
FROM users
|
FROM users
|
||||||
INNER JOIN main.profiles p on users.subject = p.subject
|
--INNER JOIN main.profiles p on users.subject = p.subject
|
||||||
LIMIT 50 OFFSET ?;
|
LIMIT 50 OFFSET ?;
|
||||||
|
|
||||||
-- name: GetUsersRoles :many
|
-- name: GetUsersRoles :many
|
||||||
@ -24,5 +23,54 @@ UPDATE users
|
|||||||
SET active = cast(? as boolean)
|
SET active = cast(? as boolean)
|
||||||
WHERE subject = ?;
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: VerifyUserEmail :exec
|
||||||
|
UPDATE users
|
||||||
|
SET email_verified=1
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
-- name: UserEmailExists :one
|
-- name: UserEmailExists :one
|
||||||
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists;
|
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists;
|
||||||
|
|
||||||
|
-- name: ModifyUserEmail :exec
|
||||||
|
UPDATE users
|
||||||
|
SET email = ?,
|
||||||
|
email_verified=?
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: ModifyUserAuth :exec
|
||||||
|
UPDATE users
|
||||||
|
SET auth_type = ?,
|
||||||
|
auth_namespace=?,
|
||||||
|
auth_user = ?
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: ModifyUserRemoteLogin :exec
|
||||||
|
UPDATE users
|
||||||
|
SET login = ?,
|
||||||
|
profile_url = ?
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: UpdateUserToken :exec
|
||||||
|
UPDATE users
|
||||||
|
SET access_token = ?,
|
||||||
|
refresh_token=?,
|
||||||
|
token_expiry = ?
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: GetUserToken :one
|
||||||
|
SELECT access_token, refresh_token, token_expiry
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: RemoveUserRoles :exec
|
||||||
|
DELETE
|
||||||
|
FROM users_roles
|
||||||
|
WHERE user_id IN (SELECT id
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?);
|
||||||
|
|
||||||
|
-- name: AddUserRole :exec
|
||||||
|
INSERT INTO users_roles(role_id, user_id)
|
||||||
|
SELECT ?, users.id
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?;
|
||||||
|
@ -1,21 +1,25 @@
|
|||||||
-- name: SetOtp :exec
|
-- name: SetOtp :exec
|
||||||
INSERT OR
|
UPDATE users
|
||||||
REPLACE
|
SET otp_secret = ?,
|
||||||
INTO otp (subject, secret, digits)
|
otp_digits=?
|
||||||
VALUES (?, ?, ?);
|
WHERE subject = ?;
|
||||||
|
|
||||||
-- name: DeleteOtp :exec
|
-- name: DeleteOtp :exec
|
||||||
DELETE
|
UPDATE users
|
||||||
FROM otp
|
SET otp_secret='',
|
||||||
WHERE otp.subject = ?;
|
otp_digits=0
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
-- name: GetOtp :one
|
-- name: GetOtp :one
|
||||||
SELECT secret, digits
|
SELECT otp_secret, otp_digits
|
||||||
FROM otp
|
FROM users
|
||||||
WHERE subject = ?;
|
WHERE subject = ?;
|
||||||
|
|
||||||
-- name: HasOtp :one
|
-- name: HasOtp :one
|
||||||
SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp;
|
SELECT CAST(1 AS BOOLEAN) AS hasOtp
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?
|
||||||
|
AND otp_secret != '';
|
||||||
|
|
||||||
-- name: GetUserEmail :one
|
-- name: GetUserEmail :one
|
||||||
SELECT email
|
SELECT email
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
-- name: GetProfile :one
|
-- name: GetProfile :one
|
||||||
SELECT profiles.*
|
SELECT subject,
|
||||||
FROM profiles
|
name,
|
||||||
|
picture,
|
||||||
|
website,
|
||||||
|
pronouns,
|
||||||
|
birthdate,
|
||||||
|
zone,
|
||||||
|
locale
|
||||||
|
FROM users
|
||||||
WHERE subject = ?;
|
WHERE subject = ?;
|
||||||
|
|
||||||
-- name: ModifyProfile :exec
|
-- name: ModifyProfile :exec
|
||||||
UPDATE profiles
|
UPDATE users
|
||||||
SET name = ?,
|
SET name = ?,
|
||||||
picture = ?,
|
picture = ?,
|
||||||
website = ?,
|
website = ?,
|
||||||
|
8
database/queries/roles.sql
Normal file
8
database/queries/roles.sql
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
-- name: AddRole :execlastid
|
||||||
|
INSERT OR IGNORE INTO roles(role)
|
||||||
|
VALUES (?);
|
||||||
|
|
||||||
|
-- name: RemoveRole :exec
|
||||||
|
DELETE
|
||||||
|
FROM roles
|
||||||
|
WHERE role = ?;
|
@ -3,15 +3,11 @@ SELECT count(subject) > 0 AS hasUser
|
|||||||
FROM users;
|
FROM users;
|
||||||
|
|
||||||
-- name: addUser :exec
|
-- name: addUser :exec
|
||||||
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active)
|
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
|
|
||||||
-- name: addOAuthUser :exec
|
|
||||||
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
|
||||||
|
|
||||||
-- name: checkLogin :one
|
-- name: checkLogin :one
|
||||||
SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified
|
SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified
|
||||||
FROM users
|
FROM users
|
||||||
WHERE users.subject = ?
|
WHERE users.subject = ?
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
@ -48,3 +44,9 @@ SET password = ?,
|
|||||||
updated_at=?
|
updated_at=?
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
AND password = ?;
|
AND password = ?;
|
||||||
|
|
||||||
|
-- name: FlagUserAsDeleted :exec
|
||||||
|
UPDATE users
|
||||||
|
SET active= false,
|
||||||
|
to_delete = true
|
||||||
|
WHERE subject = ?;
|
||||||
|
34
database/roles.sql.go
Normal file
34
database/roles.sql.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.25.0
|
||||||
|
// source: roles.sql
|
||||||
|
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const addRole = `-- name: AddRole :execlastid
|
||||||
|
INSERT OR IGNORE INTO roles(role)
|
||||||
|
VALUES (?)
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) AddRole(ctx context.Context, role string) (int64, error) {
|
||||||
|
result, err := q.db.ExecContext(ctx, addRole, role)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return result.LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
const removeRole = `-- name: RemoveRole :exec
|
||||||
|
DELETE
|
||||||
|
FROM roles
|
||||||
|
WHERE role = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RemoveRole(ctx context.Context, role string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, removeRole, role)
|
||||||
|
return err
|
||||||
|
}
|
26
database/tx.go
Normal file
26
database/tx.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errCannotOpenTransactionWithoutSqlDB = errors.New("cannot open transaction without sql.DB")
|
||||||
|
|
||||||
|
func (q *Queries) UseTx(ctx context.Context, cb func(tx *Queries) error) error {
|
||||||
|
sqlDB, ok := q.db.(*sql.DB)
|
||||||
|
if !ok {
|
||||||
|
panic(errCannotOpenTransactionWithoutSqlDB)
|
||||||
|
}
|
||||||
|
tx, err := sqlDB.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
err = cb(q.WithTx(tx))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
@ -3,7 +3,7 @@ package types
|
|||||||
type AuthType byte
|
type AuthType byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AuthTypeBase AuthType = iota
|
AuthTypeLocal AuthType = iota
|
||||||
AuthTypeOauth2
|
AuthTypeOauth2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,11 +9,24 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/password"
|
"github.com/1f349/lavender/password"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const flagUserAsDeleted = `-- name: FlagUserAsDeleted :exec
|
||||||
|
UPDATE users
|
||||||
|
SET active= false,
|
||||||
|
to_delete = true
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) FlagUserAsDeleted(ctx context.Context, subject string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, flagUserAsDeleted, subject)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const getUser = `-- name: GetUser :one
|
const getUser = `-- name: GetUser :one
|
||||||
SELECT id, subject, password, email, email_verified, updated_at, registered, active
|
SELECT id, subject, password, change_password, email, email_verified, updated_at, registered, active, name, picture, website, pronouns, birthdate, zone, locale, login, profile_url, auth_type, auth_namespace, auth_user, access_token, refresh_token, token_expiry, otp_secret, otp_digits, to_delete
|
||||||
FROM users
|
FROM users
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
@ -26,11 +39,30 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
|
|||||||
&i.ID,
|
&i.ID,
|
||||||
&i.Subject,
|
&i.Subject,
|
||||||
&i.Password,
|
&i.Password,
|
||||||
|
&i.ChangePassword,
|
||||||
&i.Email,
|
&i.Email,
|
||||||
&i.EmailVerified,
|
&i.EmailVerified,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.Registered,
|
&i.Registered,
|
||||||
&i.Active,
|
&i.Active,
|
||||||
|
&i.Name,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Pronouns,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zone,
|
||||||
|
&i.Locale,
|
||||||
|
&i.Login,
|
||||||
|
&i.ProfileUrl,
|
||||||
|
&i.AuthType,
|
||||||
|
&i.AuthNamespace,
|
||||||
|
&i.AuthUser,
|
||||||
|
&i.AccessToken,
|
||||||
|
&i.RefreshToken,
|
||||||
|
&i.TokenExpiry,
|
||||||
|
&i.OtpSecret,
|
||||||
|
&i.OtpDigits,
|
||||||
|
&i.ToDelete,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@ -98,8 +130,8 @@ func (q *Queries) UserHasRole(ctx context.Context, arg UserHasRoleParams) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
const addUser = `-- name: addUser :exec
|
const addUser = `-- name: addUser :exec
|
||||||
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active)
|
INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
type addUserParams struct {
|
type addUserParams struct {
|
||||||
@ -110,6 +142,12 @@ type addUserParams struct {
|
|||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Registered time.Time `json:"registered"`
|
Registered time.Time `json:"registered"`
|
||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Login string `json:"login"`
|
||||||
|
ChangePassword bool `json:"change_password"`
|
||||||
|
AuthType types.AuthType `json:"auth_type"`
|
||||||
|
AuthNamespace string `json:"auth_namespace"`
|
||||||
|
AuthUser string `json:"auth_user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
|
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
|
||||||
@ -121,6 +159,12 @@ func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
|
|||||||
arg.UpdatedAt,
|
arg.UpdatedAt,
|
||||||
arg.Registered,
|
arg.Registered,
|
||||||
arg.Active,
|
arg.Active,
|
||||||
|
arg.Name,
|
||||||
|
arg.Login,
|
||||||
|
arg.ChangePassword,
|
||||||
|
arg.AuthType,
|
||||||
|
arg.AuthNamespace,
|
||||||
|
arg.AuthUser,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -151,7 +195,7 @@ func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPassword
|
|||||||
}
|
}
|
||||||
|
|
||||||
const checkLogin = `-- name: checkLogin :one
|
const checkLogin = `-- name: checkLogin :one
|
||||||
SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified
|
SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified
|
||||||
FROM users
|
FROM users
|
||||||
WHERE users.subject = ?
|
WHERE users.subject = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
|
10
go.mod
10
go.mod
@ -6,12 +6,10 @@ require (
|
|||||||
github.com/1f349/cache v0.0.3
|
github.com/1f349/cache v0.0.3
|
||||||
github.com/1f349/mjwt v0.4.1
|
github.com/1f349/mjwt v0.4.1
|
||||||
github.com/1f349/overlapfs v0.0.1
|
github.com/1f349/overlapfs v0.0.1
|
||||||
github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63
|
github.com/1f349/simplemail v0.0.5
|
||||||
github.com/charmbracelet/log v0.4.0
|
github.com/charmbracelet/log v0.4.0
|
||||||
github.com/cloudflare/tableflip v1.2.3
|
github.com/cloudflare/tableflip v1.2.3
|
||||||
github.com/emersion/go-message v0.18.1
|
github.com/emersion/go-message v0.18.1
|
||||||
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43
|
|
||||||
github.com/emersion/go-smtp v0.21.3
|
|
||||||
github.com/go-oauth2/oauth2/v4 v4.5.2
|
github.com/go-oauth2/oauth2/v4 v4.5.2
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||||
github.com/golang-migrate/migrate/v4 v4.17.1
|
github.com/golang-migrate/migrate/v4 v4.17.1
|
||||||
@ -21,10 +19,13 @@ require (
|
|||||||
github.com/julienschmidt/httprouter v1.3.0
|
github.com/julienschmidt/httprouter v1.3.0
|
||||||
github.com/mattn/go-sqlite3 v1.14.22
|
github.com/mattn/go-sqlite3 v1.14.22
|
||||||
github.com/mrmelon54/pronouns v1.0.3
|
github.com/mrmelon54/pronouns v1.0.3
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/spf13/afero v1.11.0
|
github.com/spf13/afero v1.11.0
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
|
github.com/xlzd/gotp v0.1.0
|
||||||
golang.org/x/crypto v0.26.0
|
golang.org/x/crypto v0.26.0
|
||||||
golang.org/x/oauth2 v0.22.0
|
golang.org/x/oauth2 v0.22.0
|
||||||
|
golang.org/x/sync v0.8.0
|
||||||
golang.org/x/text v0.17.0
|
golang.org/x/text v0.17.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
@ -36,6 +37,8 @@ require (
|
|||||||
github.com/charmbracelet/lipgloss v0.12.1 // indirect
|
github.com/charmbracelet/lipgloss v0.12.1 // indirect
|
||||||
github.com/charmbracelet/x/ansi v0.2.1 // indirect
|
github.com/charmbracelet/x/ansi v0.2.1 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 // indirect
|
||||||
|
github.com/emersion/go-smtp v0.21.3 // indirect
|
||||||
github.com/go-jose/go-jose/v4 v4.0.4 // indirect
|
github.com/go-jose/go-jose/v4 v4.0.4 // indirect
|
||||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
@ -63,6 +66,5 @@ require (
|
|||||||
go.uber.org/atomic v1.11.0 // indirect
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
||||||
golang.org/x/net v0.28.0 // indirect
|
golang.org/x/net v0.28.0 // indirect
|
||||||
golang.org/x/sync v0.8.0 // indirect
|
|
||||||
golang.org/x/sys v0.24.0 // indirect
|
golang.org/x/sys v0.24.0 // indirect
|
||||||
)
|
)
|
||||||
|
8
go.sum
8
go.sum
@ -7,8 +7,8 @@ github.com/1f349/overlapfs v0.0.1 h1:LAxBolrXFAgU0yqZtXg/C/aaPq3eoQSPpBc49BHuTp0
|
|||||||
github.com/1f349/overlapfs v0.0.1/go.mod h1:I6aItQycr7nrzplmfNXp/QF9tTmKRSgY3fXmu/7Ky2o=
|
github.com/1f349/overlapfs v0.0.1/go.mod h1:I6aItQycr7nrzplmfNXp/QF9tTmKRSgY3fXmu/7Ky2o=
|
||||||
github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U=
|
github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U=
|
||||||
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
||||||
github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 h1:jPg+0bgKD5kY7yQtRZqeba+BGKFE51evGvwewZwa7Xc=
|
github.com/1f349/simplemail v0.0.5 h1:cr+8pdWhFE/+XVSO7ZTjntySbmIbTqmDy2SR9cHAPLE=
|
||||||
github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63/go.mod h1:1zFQhcbgiyPSWHVMp0cXJjmd6FhasP5bf5tWS4ZK61A=
|
github.com/1f349/simplemail v0.0.5/go.mod h1:ppAIqkvVkI6L99EefbR5NgOjpePNK/RKgeoehj5A+kU=
|
||||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||||
@ -146,6 +146,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99
|
|||||||
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
|
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
|
||||||
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
|
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
|
||||||
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
|
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
|
||||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
||||||
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||||
@ -199,6 +201,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
|
|||||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
||||||
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
|
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
|
||||||
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
|
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
|
||||||
|
github.com/xlzd/gotp v0.1.0 h1:37blvlKCh38s+fkem+fFh7sMnceltoIEBYTVXyoa5Po=
|
||||||
|
github.com/xlzd/gotp v0.1.0/go.mod h1:ndLJ3JKzi3xLmUProq4LLxCuECL93dG9WASNLpHz8qg=
|
||||||
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
|
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
|
||||||
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
|
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
|
||||||
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
|
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
|
||||||
|
@ -25,6 +25,7 @@ func NewManager(services map[string]SsoConfig) (*Manager, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// save by namespace
|
// save by namespace
|
||||||
|
conf.Namespace = namespace
|
||||||
l.m[namespace] = conf
|
l.m[namespace] = conf
|
||||||
}
|
}
|
||||||
return l, nil
|
return l, nil
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
package mail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/emersion/go-message/mail"
|
|
||||||
)
|
|
||||||
|
|
||||||
type FromAddress struct {
|
|
||||||
*mail.Address
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ json.Unmarshaler = &FromAddress{}
|
|
||||||
|
|
||||||
func (f *FromAddress) UnmarshalJSON(b []byte) error {
|
|
||||||
var a string
|
|
||||||
err := json.Unmarshal(b, &a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
address, err := mail.ParseAddress(a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
f.Address = address
|
|
||||||
return nil
|
|
||||||
}
|
|
116
mail/mail.go
116
mail/mail.go
@ -1,96 +1,48 @@
|
|||||||
package mail
|
package mail
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"embed"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/1f349/overlapfs"
|
||||||
|
"github.com/1f349/simplemail"
|
||||||
"github.com/emersion/go-message/mail"
|
"github.com/emersion/go-message/mail"
|
||||||
"github.com/emersion/go-sasl"
|
"io/fs"
|
||||||
"github.com/emersion/go-smtp"
|
"os"
|
||||||
"io"
|
"path/filepath"
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:embed templates/*.go.html templates/*.go.txt
|
||||||
|
var embeddedTemplates embed.FS
|
||||||
|
|
||||||
type Mail struct {
|
type Mail struct {
|
||||||
Name string `json:"name"`
|
mail *simplemail.SimpleMail
|
||||||
Tls bool `json:"tls"`
|
name string
|
||||||
Server string `json:"server"`
|
|
||||||
From FromAddress `json:"from"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mail) loginInfo() sasl.Client {
|
func New(sender *simplemail.Mail, wd, name string) (*Mail, error) {
|
||||||
return sasl.NewPlainClient("", m.Username, m.Password)
|
var o fs.FS = embeddedTemplates
|
||||||
|
o, _ = fs.Sub(o, "templates")
|
||||||
|
if wd != "" {
|
||||||
|
mailDir := filepath.Join(wd, "mail-templates")
|
||||||
|
err := os.Mkdir(mailDir, os.ModePerm)
|
||||||
|
if err == nil || errors.Is(err, os.ErrExist) {
|
||||||
|
wdFs := os.DirFS(mailDir)
|
||||||
|
o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mail) mailCall(to []string, r io.Reader) error {
|
simpleMail, err := simplemail.New(sender, o)
|
||||||
host, _, err := net.SplitHostPort(m.Server)
|
return &Mail{
|
||||||
if err != nil {
|
mail: simpleMail,
|
||||||
return err
|
name: name,
|
||||||
}
|
}, err
|
||||||
if m.Tls {
|
|
||||||
return smtp.SendMailTLS(m.Server, m.loginInfo(), m.From.String(), to, r)
|
|
||||||
}
|
|
||||||
if host == "localhost" || host == "127.0.0.1" {
|
|
||||||
// internals of smtp.SendMail without STARTTLS for localhost testing
|
|
||||||
dial, err := smtp.Dial(m.Server)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = dial.Auth(m.loginInfo())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return dial.SendMail(m.From.String(), to, r)
|
|
||||||
}
|
|
||||||
return smtp.SendMail(m.Server, m.loginInfo(), m.From.String(), to, r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mail) SendMail(subject string, to []*mail.Address, htmlBody, textBody io.Reader) error {
|
func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error {
|
||||||
// generate the email in this template
|
return m.mail.Send(templateName, fmt.Sprintf("%s - %s", subject, m.name), to, map[string]any{
|
||||||
buf := new(bytes.Buffer)
|
"ServiceName": m.name,
|
||||||
|
"Name": nameOfUser,
|
||||||
// setup mail headers
|
"Data": data,
|
||||||
var h mail.Header
|
})
|
||||||
h.SetDate(time.Now())
|
|
||||||
h.SetSubject(subject)
|
|
||||||
h.SetAddressList("From", []*mail.Address{m.From.Address})
|
|
||||||
h.SetAddressList("To", to)
|
|
||||||
h.Set("Content-Type", "multipart/alternative")
|
|
||||||
|
|
||||||
// setup html and text alternative headers
|
|
||||||
var hHtml, hTxt mail.InlineHeader
|
|
||||||
hHtml.Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
hTxt.Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
|
|
||||||
createWriter, err := mail.CreateWriter(buf, h)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
inline, err := createWriter.CreateInline()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
partHtml, err := inline.CreatePart(hHtml)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := io.Copy(partHtml, htmlBody); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
partTxt, err := inline.CreatePart(hTxt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := io.Copy(partTxt, textBody); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert all to addresses to strings
|
|
||||||
toStr := make([]string, len(to))
|
|
||||||
for i := range toStr {
|
|
||||||
toStr[i] = to[i].String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.mailCall(toStr, buf)
|
|
||||||
}
|
}
|
||||||
|
@ -1,18 +0,0 @@
|
|||||||
package mail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"github.com/1f349/lavender/mail/templates"
|
|
||||||
"github.com/emersion/go-message/mail"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error {
|
|
||||||
var bufHtml, bufTxt bytes.Buffer
|
|
||||||
templates.RenderMailTemplate(&bufHtml, &bufTxt, templateName, map[string]any{
|
|
||||||
"ServiceName": m.Name,
|
|
||||||
"Name": nameOfUser,
|
|
||||||
"Data": data,
|
|
||||||
})
|
|
||||||
return m.SendMail(fmt.Sprintf("%s - %s", subject, m.Name), []*mail.Address{to}, &bufHtml, &bufTxt)
|
|
||||||
}
|
|
@ -1,55 +0,0 @@
|
|||||||
package templates
|
|
||||||
|
|
||||||
import (
|
|
||||||
"embed"
|
|
||||||
"errors"
|
|
||||||
"github.com/1f349/overlapfs"
|
|
||||||
"github.com/1f349/tulip/logger"
|
|
||||||
htmlTemplate "html/template"
|
|
||||||
"io"
|
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
textTemplate "text/template"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
//go:embed *.go.html *.go.txt
|
|
||||||
embeddedTemplates embed.FS
|
|
||||||
mailHtmlTemplates *htmlTemplate.Template
|
|
||||||
mailTextTemplates *textTemplate.Template
|
|
||||||
loadOnce sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
func LoadMailTemplates(wd string) (err error) {
|
|
||||||
loadOnce.Do(func() {
|
|
||||||
var o fs.FS = embeddedTemplates
|
|
||||||
if wd != "" {
|
|
||||||
mailDir := filepath.Join(wd, "mail-templates")
|
|
||||||
err = os.Mkdir(mailDir, os.ModePerm)
|
|
||||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wdFs := os.DirFS(mailDir)
|
|
||||||
o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs}
|
|
||||||
}
|
|
||||||
mailHtmlTemplates, err = htmlTemplate.New("mail").ParseFS(o, "*.go.html")
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
mailTextTemplates, err = textTemplate.New("mail").ParseFS(o, "*.go.txt")
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func RenderMailTemplate(wrHtml, wrTxt io.Writer, name string, data any) {
|
|
||||||
err := mailHtmlTemplates.ExecuteTemplate(wrHtml, name+".go.html", data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Logger.Warn("Failed to render mail html", "name", name, "err", err)
|
|
||||||
}
|
|
||||||
err = mailTextTemplates.ExecuteTemplate(wrTxt, name+".go.txt", data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Logger.Warn("Failed to render mail text", "name", name, "err", err)
|
|
||||||
}
|
|
||||||
}
|
|
@ -59,7 +59,7 @@ func (h *httpServer) RequireAdminAuthentication(next UserHandler) httprouter.Han
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle {
|
func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle {
|
||||||
return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
return h.OptionalAuthentication(false, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||||
if auth.IsGuest() {
|
if auth.IsGuest() {
|
||||||
redirectUrl := PrepareRedirectUrl("/login", req.URL)
|
redirectUrl := PrepareRedirectUrl("/login", req.URL)
|
||||||
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
|
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
|
||||||
@ -69,16 +69,20 @@ func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpServer) OptionalAuthentication(next UserHandler) httprouter.Handle {
|
func (h *httpServer) OptionalAuthentication(flowPart bool, next UserHandler) httprouter.Handle {
|
||||||
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
authUser, err := h.internalAuthenticationHandler(rw, req)
|
authData, err := h.internalAuthenticationHandler(rw, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, ErrAuthHttpError) {
|
if !errors.Is(err, ErrAuthHttpError) {
|
||||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next(rw, req, params, authUser)
|
if n := authData.NextFlowUrl(req.URL); n != nil && !flowPart {
|
||||||
|
http.Redirect(rw, req, n.String(), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(rw, req, params, authData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
73
server/auth_test.go
Normal file
73
server/auth_test.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/1f349/mjwt"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserAuth_NextFlowUrl(t *testing.T) {
|
||||||
|
u := UserAuth{NeedOtp: true}
|
||||||
|
assert.Equal(t, url.URL{Path: "/login/otp"}, *u.NextFlowUrl(&url.URL{}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello"}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()}))
|
||||||
|
u.NeedOtp = false
|
||||||
|
assert.Nil(t, u.NextFlowUrl(&url.URL{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserAuth_IsGuest(t *testing.T) {
|
||||||
|
var u UserAuth
|
||||||
|
assert.True(t, u.IsGuest())
|
||||||
|
u.Subject = uuid.NewString()
|
||||||
|
assert.False(t, u.IsGuest())
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeSessionStore struct {
|
||||||
|
m map[string]any
|
||||||
|
saveFunc func(map[string]any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionStore) Context() context.Context { return context.Background() }
|
||||||
|
func (f *fakeSessionStore) SessionID() string { return "fakeSessionStore" }
|
||||||
|
func (f *fakeSessionStore) Set(key string, value interface{}) { f.m[key] = value }
|
||||||
|
|
||||||
|
func (f *fakeSessionStore) Get(key string) (a interface{}, ok bool) {
|
||||||
|
if a, ok = f.m[key]; false {
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuthentication(t *testing.T) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionalAuthentication(t *testing.T) {
|
||||||
|
jwtIssuer, err := mjwt.NewIssuer("TestIssuer", uuid.NewString(), jwt.SigningMethodRS512)
|
||||||
|
h := &httpServer{signingKey: jwtIssuer}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
auth, err := h.internalAuthenticationHandler(rec, req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, auth.IsGuest())
|
||||||
|
auth.Subject = "567"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareRedirectUrl(t *testing.T) {
|
||||||
|
assert.Equal(t, url.URL{Path: "/hello"}, *PrepareRedirectUrl("/hello", &url.URL{}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/world"}, *PrepareRedirectUrl("/world", &url.URL{}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello"}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()}))
|
||||||
|
|
||||||
|
assert.Equal(t, url.URL{Path: "/hello", RawQuery: "z=y"}, *PrepareRedirectUrl("/hello?z=y", &url.URL{}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/world", RawQuery: "z=y"}, *PrepareRedirectUrl("/world?z=y", &url.URL{}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello"}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()}))
|
||||||
|
assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()}))
|
||||||
|
}
|
87
server/edit.go
Normal file
87
server/edit.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/1f349/lavender/lists"
|
||||||
|
"github.com/1f349/lavender/pages"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *httpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
|
var user database.User
|
||||||
|
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
var err error
|
||||||
|
user, err = tx.GetUser(req.Context(), auth.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read user data: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lNonce := uuid.NewString()
|
||||||
|
http.SetCookie(rw, &http.Cookie{
|
||||||
|
Name: "tulip-nonce",
|
||||||
|
Value: lNonce,
|
||||||
|
Path: "/",
|
||||||
|
Expires: time.Now().Add(10 * time.Minute),
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
pages.RenderPageTemplate(rw, "edit", map[string]any{
|
||||||
|
"ServiceName": h.conf.ServiceName,
|
||||||
|
"User": user,
|
||||||
|
"Nonce": lNonce,
|
||||||
|
"FieldPronoun": user.Pronouns.String(),
|
||||||
|
"ListZoneInfo": lists.ListZoneInfo(),
|
||||||
|
"ListLocale": lists.ListLocale(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func (h *httpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
|
if req.ParseForm() != nil {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = rw.Write([]byte("400 Bad Request\n"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var patch database.ProfilePatch
|
||||||
|
errs := patch.ParseFromForm(req.Form)
|
||||||
|
if len(errs) > 0 {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = fmt.Fprintln(rw, "<!DOCTYPE html>\n<html>\n<body>")
|
||||||
|
_, _ = fmt.Fprintln(rw, "<p>400 Bad Request: Failed to parse form data, press the back button in your browser, check your inputs and try again.</p>")
|
||||||
|
_, _ = fmt.Fprintln(rw, "<ul>")
|
||||||
|
for _, i := range errs {
|
||||||
|
_, _ = fmt.Fprintf(rw, " <li>%s</li>\n", i)
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintln(rw, "</ul>")
|
||||||
|
_, _ = fmt.Fprintln(rw, "</body>\n</html>")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m := database.ModifyProfileParams{
|
||||||
|
Name: patch.Name,
|
||||||
|
Picture: patch.Picture,
|
||||||
|
Website: patch.Website,
|
||||||
|
Pronouns: patch.Pronouns,
|
||||||
|
Birthdate: patch.Birthdate,
|
||||||
|
Zone: patch.Zone.String(),
|
||||||
|
Locale: patch.Locale,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
Subject: auth.Subject,
|
||||||
|
}
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
if err := tx.ModifyProfile(req.Context(), m); err != nil {
|
||||||
|
return fmt.Errorf("failed to modify user info: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(rw, req, "/edit", http.StatusFound)
|
||||||
|
}
|
@ -42,4 +42,51 @@ func (h *httpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
|
|||||||
"Nonce": lNonce,
|
"Nonce": lNonce,
|
||||||
"IsAdmin": isAdmin,
|
"IsAdmin": isAdmin,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// rw.Header().Set("Content-Type", "text/html")
|
||||||
|
// lNonce := uuid.NewString()
|
||||||
|
// http.SetCookie(rw, &http.Cookie{
|
||||||
|
// Name: "tulip-nonce",
|
||||||
|
// Value: lNonce,
|
||||||
|
// Path: "/",
|
||||||
|
// Expires: time.Now().Add(10 * time.Minute),
|
||||||
|
// Secure: true,
|
||||||
|
// SameSite: http.SameSiteLaxMode,
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// if auth.IsGuest() {
|
||||||
|
// pages.RenderPageTemplate(rw, "index-guest", map[string]any{
|
||||||
|
// "ServiceName": h.conf.ServiceName,
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// var userWithName string
|
||||||
|
// var userRole types.UserRole
|
||||||
|
// var hasTwoFactor bool
|
||||||
|
// if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
|
// userWithName, err = tx.GetUserDisplayName(req.Context(), auth.Subject)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to get user display name: %w", err)
|
||||||
|
// }
|
||||||
|
// hasTwoFactor, err = tx.HasOtp(req.Context(), auth.Subject)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to get user two factor state: %w", err)
|
||||||
|
// }
|
||||||
|
// userRole, err = tx.GetUserRole(req.Context(), auth.Subject)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to get user role: %w", err)
|
||||||
|
// }
|
||||||
|
// return
|
||||||
|
// }) {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// pages.RenderPageTemplate(rw, "index", map[string]any{
|
||||||
|
// "ServiceName": h.conf.ServiceName,
|
||||||
|
// "Auth": auth,
|
||||||
|
// "User": database.User{Subject: auth.Subject, Name: userWithName, Role: userRole},
|
||||||
|
// "Nonce": lNonce,
|
||||||
|
// "OtpEnabled": hasTwoFactor,
|
||||||
|
// "IsAdmin": userRole == types.RoleAdmin,
|
||||||
|
// })
|
||||||
}
|
}
|
||||||
|
115
server/login.go
115
server/login.go
@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
auth2 "github.com/1f349/lavender/auth"
|
auth2 "github.com/1f349/lavender/auth"
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/pages"
|
"github.com/1f349/lavender/pages"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
@ -15,13 +16,31 @@ import (
|
|||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"github.com/mrmelon54/pronouns"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/text/language"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getUserLoginName finds the `login_name` query parameter within the `/authorize` redirect url
|
||||||
|
func getUserLoginName(req *http.Request) string {
|
||||||
|
q := req.URL.Query()
|
||||||
|
if !q.Has("redirect") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
originUrl, err := url.ParseRequestURI(q.Get("redirect"))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if originUrl.Path != "/authorize" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return originUrl.Query().Get("login_name")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
if !auth.IsGuest() {
|
if !auth.IsGuest() {
|
||||||
h.SafeRedirect(rw, req)
|
h.SafeRedirect(rw, req)
|
||||||
@ -131,41 +150,70 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = h.DbTxError(func(tx *database.Queries) error {
|
err = h.DbTxError(func(tx *database.Queries) error {
|
||||||
jBytes, err := json.Marshal(sessionData.UserInfo)
|
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
|
||||||
|
|
||||||
|
_, err = tx.GetUser(req.Context(), sessionData.Subject)
|
||||||
|
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
||||||
|
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
_, err := tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{
|
||||||
|
Email: uEmail,
|
||||||
|
EmailVerified: uEmailVerified,
|
||||||
|
Name: name,
|
||||||
|
Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
|
||||||
|
AuthNamespace: sso.Namespace,
|
||||||
|
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.ModifyUserEmail(req.Context(), database.ModifyUserEmailParams{
|
||||||
|
Email: uEmail,
|
||||||
|
EmailVerified: uEmailVerified,
|
||||||
|
Subject: sessionData.Subject,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.GetUser(req.Context(), sessionData.Subject)
|
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
err = tx.ModifyUserAuth(req.Context(), database.ModifyUserAuthParams{
|
||||||
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
AuthType: types.AuthTypeOauth2,
|
||||||
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
AuthNamespace: sso.Namespace,
|
||||||
id, err := tx.AddUser(req.Context(), database.AddUserParams{
|
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
|
||||||
Name: "",
|
|
||||||
Subject: sessionData.Subject,
|
Subject: sessionData.Subject,
|
||||||
Password: "",
|
|
||||||
Email: uEmail,
|
|
||||||
EmailVerified: uEmailVerified,
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
Active: true,
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
return tx.AddUser(req.Context(), database.AddUserParams{
|
|
||||||
Subject: sessionData.Subject,
|
|
||||||
Email: uEmail,
|
|
||||||
EmailVerified: uEmailVerified,
|
|
||||||
Roles: "",
|
|
||||||
Userinfo: string(jBytes),
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
Active: true,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
|
||||||
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
err = tx.ModifyUserRemoteLogin(req.Context(), database.ModifyUserRemoteLoginParams{
|
||||||
return tx.UpdateUserInfo(req.Context(), database.UpdateUserInfoParams{
|
Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
|
||||||
Email: sessionData.Subject,
|
ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"),
|
||||||
EmailVerified: uEmailVerified,
|
Subject: sessionData.Subject,
|
||||||
Userinfo: string(jBytes),
|
})
|
||||||
Subject: uEmail,
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns"))
|
||||||
|
if err != nil {
|
||||||
|
pronoun = pronouns.TheyThem
|
||||||
|
}
|
||||||
|
locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale"))
|
||||||
|
if err != nil {
|
||||||
|
locale = language.AmericanEnglish
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.ModifyProfile(req.Context(), database.ModifyProfileParams{
|
||||||
|
Name: name,
|
||||||
|
Picture: sessionData.UserInfo.GetStringOrEmpty("profile"),
|
||||||
|
Website: sessionData.UserInfo.GetStringOrEmpty("website"),
|
||||||
|
Pronouns: types.UserPronoun{Pronoun: pronoun},
|
||||||
|
Birthdate: sessionData.UserInfo.GetNullDate("birthdate"),
|
||||||
|
Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"),
|
||||||
|
Locale: types.UserLocale{Tag: locale},
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
Subject: sessionData.Subject,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -177,7 +225,7 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK
|
|||||||
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
|
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
|
||||||
AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
|
AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
|
||||||
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
|
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
|
||||||
Expiry: sql.NullTime{Time: token.Expiry, Valid: true},
|
TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true},
|
||||||
Subject: sessionData.Subject,
|
Subject: sessionData.Subject,
|
||||||
})
|
})
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@ -208,6 +256,11 @@ func (l lavenderLoginRefresh) Valid() error { return l.RefreshTokenClaims.Valid(
|
|||||||
|
|
||||||
func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" }
|
func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" }
|
||||||
|
|
||||||
|
func (h *httpServer) setLoginDataCookie2(rw http.ResponseWriter, authData UserAuth) bool {
|
||||||
|
// TODO(melon): should probably merge there methods
|
||||||
|
return h.setLoginDataCookie(rw, authData, "")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool {
|
func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool {
|
||||||
ps := auth.NewPermStorage()
|
ps := auth.NewPermStorage()
|
||||||
accId := uuid.NewString()
|
accId := uuid.NewString()
|
||||||
@ -286,13 +339,13 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !token.AccessToken.Valid || !token.RefreshToken.Valid || !token.Expiry.Valid {
|
if !token.AccessToken.Valid || !token.RefreshToken.Valid || !token.TokenExpiry.Valid {
|
||||||
return fmt.Errorf("invalid oauth token")
|
return fmt.Errorf("invalid oauth token")
|
||||||
}
|
}
|
||||||
oauthToken = &oauth2.Token{
|
oauthToken = &oauth2.Token{
|
||||||
AccessToken: token.AccessToken.String,
|
AccessToken: token.AccessToken.String,
|
||||||
RefreshToken: token.RefreshToken.String,
|
RefreshToken: token.RefreshToken.String,
|
||||||
Expiry: token.Expiry.Time,
|
Expiry: token.TokenExpiry.Time,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
25
server/logout.go
Normal file
25
server/logout.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *httpServer) logoutPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ UserAuth) {
|
||||||
|
http.SetCookie(rw, &http.Cookie{
|
||||||
|
Name: "lavender-login-access",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
http.SetCookie(rw, &http.Cookie{
|
||||||
|
Name: "lavender-login-refresh",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(rw, req, "/", http.StatusFound)
|
||||||
|
}
|
123
server/mail.go
Normal file
123
server/mail.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/1f349/lavender/pages"
|
||||||
|
"github.com/emersion/go-message/mail"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *httpServer) MailVerify(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
code := params.ByName("code")
|
||||||
|
|
||||||
|
k := mailLinkKey{mailLinkVerifyEmail, code}
|
||||||
|
|
||||||
|
userSub, ok := h.mailLinkCache.Get(k)
|
||||||
|
if !ok {
|
||||||
|
http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
return tx.VerifyUserEmail(req.Context(), userSub)
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mailLinkCache.Delete(k)
|
||||||
|
|
||||||
|
http.Error(rw, "Email address has been verified, you may close this tab and return to the login page.", http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) MailPassword(rw http.ResponseWriter, _ *http.Request, params httprouter.Params) {
|
||||||
|
code := params.ByName("code")
|
||||||
|
|
||||||
|
k := mailLinkKey{mailLinkResetPassword, code}
|
||||||
|
_, ok := h.mailLinkCache.Get(k)
|
||||||
|
if !ok {
|
||||||
|
http.Error(rw, "Invalid password reset code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pages.RenderPageTemplate(rw, "reset-password", map[string]any{
|
||||||
|
"ServiceName": h.conf.ServiceName,
|
||||||
|
"Code": code,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||||
|
pw := req.PostFormValue("new_password")
|
||||||
|
rpw := req.PostFormValue("confirm_password")
|
||||||
|
code := req.PostFormValue("code")
|
||||||
|
|
||||||
|
// reverse passwords are possible
|
||||||
|
if len(pw) == 0 {
|
||||||
|
http.Error(rw, "Cannot set an empty password", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// bcrypt only allows up to 72 bytes anyway
|
||||||
|
if len(pw) > 64 {
|
||||||
|
http.Error(rw, "Security by extremely long password is a weird flex", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rpw != pw {
|
||||||
|
http.Error(rw, "Passwords do not match", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
k := mailLinkKey{mailLinkResetPassword, code}
|
||||||
|
userSub, ok := h.mailLinkCache.Get(k)
|
||||||
|
if !ok {
|
||||||
|
http.Error(rw, "Invalid password reset code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mailLinkCache.Delete(k)
|
||||||
|
|
||||||
|
// reset password database call
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
return tx.ChangePassword(req.Context(), userSub, pw)
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(rw, "Reset password successfully, you can login now.", http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) MailDelete(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
code := params.ByName("code")
|
||||||
|
|
||||||
|
k := mailLinkKey{mailLinkDelete, code}
|
||||||
|
userSub, ok := h.mailLinkCache.Get(k)
|
||||||
|
if !ok {
|
||||||
|
http.Error(rw, "Invalid email delete code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var userInfo database.User
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
|
userInfo, err = tx.GetUser(req.Context(), userSub)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return tx.FlagUserAsDeleted(req.Context(), userSub)
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mailLinkCache.Delete(k)
|
||||||
|
|
||||||
|
// parse email for headers
|
||||||
|
address, err := mail.ParseAddress(userInfo.Email)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "500 Internal Server Error: Failed to parse user email address", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.conf.Mail.SendEmailTemplate("mail-account-delete", "Account Deletion", userInfo.Name, address, nil)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "Failed to send confirmation email.", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(rw, "You will receive an email shortly to verify this action, you may close this tab.", http.StatusOK)
|
||||||
|
}
|
@ -12,11 +12,17 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func SetupManageApps(r *httprouter.Router, hs *httpServer) {
|
||||||
|
r.GET("/manage/apps", hs.RequireAuthentication(hs.ManageAppsGet))
|
||||||
|
r.GET("/manage/apps/create", hs.RequireAuthentication(hs.ManageAppsCreateGet))
|
||||||
|
r.POST("/manage/apps", hs.RequireAuthentication(hs.ManageAppsPost))
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
q := req.URL.Query()
|
q := req.URL.Query()
|
||||||
offset, _ := strconv.Atoi(q.Get("offset"))
|
offset, _ := strconv.Atoi(q.Get("offset"))
|
||||||
|
|
||||||
var roles string
|
var roles []string
|
||||||
var appList []database.GetAppListRow
|
var appList []database.GetAppListRow
|
||||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||||
@ -24,7 +30,7 @@ func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{
|
appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{
|
||||||
Owner: auth.Subject,
|
OwnerSubject: auth.Subject,
|
||||||
Column2: HasRole(roles, role.LavenderAdmin),
|
Column2: HasRole(roles, role.LavenderAdmin),
|
||||||
Offset: int64(offset),
|
Offset: int64(offset),
|
||||||
})
|
})
|
||||||
@ -61,7 +67,7 @@ func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
func (h *httpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
var roles string
|
var roles []string
|
||||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||||
return
|
return
|
||||||
@ -96,7 +102,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
active := req.Form.Has("active")
|
active := req.Form.Has("active")
|
||||||
|
|
||||||
if sso || hasPerms {
|
if sso || hasPerms {
|
||||||
var roles string
|
var roles []string
|
||||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||||
return
|
return
|
||||||
@ -125,7 +131,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
Name: name,
|
Name: name,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
Owner: auth.Subject,
|
OwnerSubject: auth.Subject,
|
||||||
Perms: perms,
|
Perms: perms,
|
||||||
Public: public,
|
Public: public,
|
||||||
Sso: sso,
|
Sso: sso,
|
||||||
@ -145,7 +151,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
Sso: sso,
|
Sso: sso,
|
||||||
Active: active,
|
Active: active,
|
||||||
Subject: req.FormValue("subject"),
|
Subject: req.FormValue("subject"),
|
||||||
Owner: auth.Subject,
|
OwnerSubject: auth.Subject,
|
||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
@ -166,7 +172,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{
|
err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
Subject: sub,
|
Subject: sub,
|
||||||
Owner: auth.Subject,
|
OwnerSubject: auth.Subject,
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}) {
|
}) {
|
||||||
|
@ -5,16 +5,22 @@ import (
|
|||||||
"github.com/1f349/lavender/pages"
|
"github.com/1f349/lavender/pages"
|
||||||
"github.com/1f349/lavender/role"
|
"github.com/1f349/lavender/role"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func SetupManageUsers(r *httprouter.Router, hs *httpServer) {
|
||||||
|
r.GET("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersGet))
|
||||||
|
r.POST("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersPost))
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
func (h *httpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
q := req.URL.Query()
|
q := req.URL.Query()
|
||||||
offset, _ := strconv.Atoi(q.Get("offset"))
|
offset, _ := strconv.Atoi(q.Get("offset"))
|
||||||
|
|
||||||
var roles string
|
var roles []string
|
||||||
var userList []database.GetUserListRow
|
var userList []database.GetUserListRow
|
||||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||||
@ -64,7 +70,7 @@ func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var roles string
|
var roles []string
|
||||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
roles, err = tx.GetUserRoles(req.Context(), auth.Subject)
|
||||||
return
|
return
|
||||||
@ -78,18 +84,38 @@ func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
|
|
||||||
offset := req.Form.Get("offset")
|
offset := req.Form.Get("offset")
|
||||||
action := req.Form.Get("action")
|
action := req.Form.Get("action")
|
||||||
newRoles := req.Form.Get("roles")
|
newRoles := req.Form["roles"]
|
||||||
active := req.Form.Has("active")
|
active := req.Form.Has("active")
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case "edit":
|
case "edit":
|
||||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
sub := req.Form.Get("subject")
|
sub := req.Form.Get("subject")
|
||||||
return tx.UpdateUser(req.Context(), database.UpdateUserParams{
|
return tx.UseTx(req.Context(), func(tx *database.Queries) (err error) {
|
||||||
Active: active,
|
err = tx.ChangeUserActive(req.Context(), database.ChangeUserActiveParams{Column1: active, Subject: sub})
|
||||||
Roles: newRoles,
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = tx.RemoveUserRoles(req.Context(), sub)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
errGrp := new(errgroup.Group)
|
||||||
|
errGrp.SetLimit(3)
|
||||||
|
for _, roleName := range newRoles {
|
||||||
|
errGrp.Go(func() error {
|
||||||
|
roleId, err := strconv.ParseInt(roleName, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.AddUserRole(req.Context(), database.AddUserRoleParams{
|
||||||
|
RoleID: roleId,
|
||||||
Subject: sub,
|
Subject: sub,
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return errGrp.Wait()
|
||||||
|
})
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
130
server/oauth.go
130
server/oauth.go
@ -1,15 +1,143 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
clientStore "github.com/1f349/lavender/client-store"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/logger"
|
"github.com/1f349/lavender/logger"
|
||||||
"github.com/1f349/lavender/pages"
|
"github.com/1f349/lavender/pages"
|
||||||
"github.com/1f349/lavender/scope"
|
"github.com/1f349/lavender/scope"
|
||||||
|
"github.com/1f349/lavender/utils"
|
||||||
|
"github.com/1f349/mjwt"
|
||||||
|
"github.com/go-oauth2/oauth2/v4/generates"
|
||||||
|
"github.com/go-oauth2/oauth2/v4/manage"
|
||||||
|
"github.com/go-oauth2/oauth2/v4/server"
|
||||||
|
"github.com/go-oauth2/oauth2/v4/store"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *database.Queries) {
|
||||||
|
oauthManager := manage.NewManager()
|
||||||
|
oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
|
||||||
|
oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||||
|
oauthManager.MustTokenStorage(store.NewMemoryTokenStore())
|
||||||
|
oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(key, db))
|
||||||
|
oauthManager.MapClientStorage(clientStore.New(db))
|
||||||
|
|
||||||
|
oauthSrv := server.NewDefaultServer(oauthManager)
|
||||||
|
oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) {
|
||||||
|
cId, cSecret, err := server.ClientBasicHandler(req)
|
||||||
|
if cId == "" && cSecret == "" {
|
||||||
|
cId, cSecret, err = server.ClientFormHandler(req)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return cId, cSecret, nil
|
||||||
|
})
|
||||||
|
oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization)
|
||||||
|
oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (string, error) {
|
||||||
|
var form url.Values
|
||||||
|
if req.Method == http.MethodPost {
|
||||||
|
form = req.PostForm
|
||||||
|
} else {
|
||||||
|
form = req.URL.Query()
|
||||||
|
}
|
||||||
|
a := form.Get("scope")
|
||||||
|
if !scope.ScopesExist(a) {
|
||||||
|
return "", errInvalidScope
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
})
|
||||||
|
addIdTokenSupport(oauthSrv, db, key)
|
||||||
|
|
||||||
|
r.GET("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint))
|
||||||
|
r.POST("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint))
|
||||||
|
r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
if err := oauthSrv.HandleTokenRequest(rw, req); err != nil {
|
||||||
|
http.Error(rw, "Failed to handle token request", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) userInfoRequest(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
rw.Header().Set("Access-Control-Allow-Headers", "Authorization,Content-Type")
|
||||||
|
rw.Header().Set("Access-Control-Allow-Origin", strings.TrimSuffix(req.Referer(), "/"))
|
||||||
|
rw.Header().Set("Access-Control-Allow-Methods", "GET")
|
||||||
|
if req.Method == http.MethodOptions {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := h.oauthSrv.ValidationBearerToken(req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId := token.GetUserID()
|
||||||
|
|
||||||
|
sso := h.manager.FindServiceFromLogin(userId)
|
||||||
|
if sso == nil {
|
||||||
|
http.Error(rw, "Invalid user", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var user database.User
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
|
user, err = tx.GetUser(req.Context(), userId)
|
||||||
|
return
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := ParseClaims(token.GetScope())
|
||||||
|
if !claims["openid"] {
|
||||||
|
http.Error(rw, "Invalid scope", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m := make(map[string]any)
|
||||||
|
|
||||||
|
if claims["name"] {
|
||||||
|
m["name"] = user.Name
|
||||||
|
}
|
||||||
|
if claims["username"] {
|
||||||
|
m["preferred_username"] = user.Login
|
||||||
|
m["login"] = user.Login
|
||||||
|
}
|
||||||
|
if claims["profile"] {
|
||||||
|
m["profile"] = user.ProfileUrl
|
||||||
|
m["picture"] = user.Picture
|
||||||
|
m["website"] = user.Website
|
||||||
|
}
|
||||||
|
if claims["email"] {
|
||||||
|
m["email"] = user.Email
|
||||||
|
m["email_verified"] = user.EmailVerified
|
||||||
|
}
|
||||||
|
if claims["birthdate"] && user.Birthdate.Valid {
|
||||||
|
m["birthdate"] = user.Birthdate.Date
|
||||||
|
}
|
||||||
|
if claims["age"] && user.Birthdate.Valid {
|
||||||
|
m["age"] = utils.Age(user.Birthdate.Date.ToTime())
|
||||||
|
}
|
||||||
|
if claims["zoneinfo"] {
|
||||||
|
m["zoneinfo"] = user.Zone
|
||||||
|
}
|
||||||
|
if claims["locale"] {
|
||||||
|
m["locale"] = user.Locale
|
||||||
|
}
|
||||||
|
|
||||||
|
m["sub"] = userId
|
||||||
|
m["aud"] = token.GetClientID()
|
||||||
|
m["updated_at"] = time.Now().Unix()
|
||||||
|
|
||||||
|
_ = json.NewEncoder(rw).Encode(m)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
// function is only called with GET or POST method
|
// function is only called with GET or POST method
|
||||||
isPost := req.Method == http.MethodPost
|
isPost := req.Method == http.MethodPost
|
||||||
@ -95,7 +223,7 @@ func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request
|
|||||||
"ServiceName": h.conf.ServiceName,
|
"ServiceName": h.conf.ServiceName,
|
||||||
"AppName": appName,
|
"AppName": appName,
|
||||||
"AppDomain": appDomain,
|
"AppDomain": appDomain,
|
||||||
"DisplayName": auth.DisplayName,
|
"DisplayName": auth.UserInfo.GetStringOrEmpty("name"),
|
||||||
"WantsList": scope.FancyScopeList(scopeList),
|
"WantsList": scope.FancyScopeList(scopeList),
|
||||||
"ResponseType": form.Get("response_type"),
|
"ResponseType": form.Get("response_type"),
|
||||||
"ResponseMode": form.Get("response_mode"),
|
"ResponseMode": form.Get("response_mode"),
|
||||||
|
196
server/otp.go
Normal file
196
server/otp.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/1f349/lavender/pages"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"github.com/skip2/go-qrcode"
|
||||||
|
"github.com/xlzd/gotp"
|
||||||
|
"html/template"
|
||||||
|
"image/png"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *httpServer) loginOtpGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
|
if !auth.NeedOtp {
|
||||||
|
h.SafeRedirect(rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pages.RenderPageTemplate(rw, "login-otp", map[string]any{
|
||||||
|
"ServiceName": h.conf.ServiceName,
|
||||||
|
"Redirect": req.URL.Query().Get("redirect"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) loginOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
|
if !auth.NeedOtp {
|
||||||
|
http.Redirect(rw, req, "/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
otpInput := req.FormValue("code")
|
||||||
|
if h.fetchAndValidateOtp(rw, auth.Subject, otpInput) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.NeedOtp = false
|
||||||
|
|
||||||
|
h.setLoginDataCookie2(rw, auth)
|
||||||
|
h.SafeRedirect(rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code string) bool {
|
||||||
|
var hasOtp bool
|
||||||
|
var otpRow database.GetOtpRow
|
||||||
|
var secret string
|
||||||
|
var digits int64
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
|
hasOtp, err = tx.HasOtp(context.Background(), sub)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if hasOtp {
|
||||||
|
otpRow, err = tx.GetOtp(context.Background(), sub)
|
||||||
|
secret = otpRow.OtpSecret
|
||||||
|
digits = otpRow.OtpDigits
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasOtp {
|
||||||
|
totp := gotp.NewTOTP(secret, int(digits), 30, nil)
|
||||||
|
if !verifyTotp(totp, code) {
|
||||||
|
http.Error(rw, "400 Bad Request: Invalid OTP code", http.StatusBadRequest)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpServer) editOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
|
if req.Method == http.MethodPost && req.FormValue("remove") == "1" {
|
||||||
|
if !req.Form.Has("code") {
|
||||||
|
// render page
|
||||||
|
pages.RenderPageTemplate(rw, "remove-otp", map[string]any{
|
||||||
|
"ServiceName": h.conf.ServiceName,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
otpInput := req.Form.Get("code")
|
||||||
|
if h.fetchAndValidateOtp(rw, auth.Subject, otpInput) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
return tx.DeleteOtp(req.Context(), auth.Subject)
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(rw, req, "/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var digits int
|
||||||
|
switch req.FormValue("digits") {
|
||||||
|
case "6":
|
||||||
|
digits = 6
|
||||||
|
case "7":
|
||||||
|
digits = 7
|
||||||
|
case "8":
|
||||||
|
digits = 8
|
||||||
|
default:
|
||||||
|
http.Error(rw, "400 Bad Request: Invalid number of digits for OTP code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secret := req.FormValue("secret")
|
||||||
|
if !gotp.IsSecretValid(secret) {
|
||||||
|
http.Error(rw, "400 Bad Request: Invalid secret", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret == "" {
|
||||||
|
// get user email
|
||||||
|
var email string
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
var err error
|
||||||
|
email, err = tx.GetUserEmail(req.Context(), auth.Subject)
|
||||||
|
return err
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secret = gotp.RandomSecret(64)
|
||||||
|
if secret == "" {
|
||||||
|
http.Error(rw, "500 Internal Server Error: failed to generate OTP secret", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
totp := gotp.NewTOTP(secret, digits, 30, nil)
|
||||||
|
otpUri := totp.ProvisioningUri(email, h.conf.OtpIssuer)
|
||||||
|
code, err := qrcode.New(otpUri, qrcode.Medium)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "500 Internal Server Error: failed to generate QR code", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
qrImg := code.Image(60 * 4)
|
||||||
|
qrBounds := qrImg.Bounds()
|
||||||
|
qrWidth := qrBounds.Dx()
|
||||||
|
|
||||||
|
qrBuf := new(bytes.Buffer)
|
||||||
|
if png.Encode(qrBuf, qrImg) != nil {
|
||||||
|
http.Error(rw, "500 Internal Server Error: failed to generate PNG image of QR code", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// render page
|
||||||
|
pages.RenderPageTemplate(rw, "edit-otp", map[string]any{
|
||||||
|
"ServiceName": h.conf.ServiceName,
|
||||||
|
"OtpQr": template.URL("data:qrImg/png;base64," + base64.StdEncoding.EncodeToString(qrBuf.Bytes())),
|
||||||
|
"QrWidth": qrWidth,
|
||||||
|
"OtpUrl": otpUri,
|
||||||
|
"OtpSecret": secret,
|
||||||
|
"OtpDigits": digits,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
totp := gotp.NewTOTP(secret, digits, 30, nil)
|
||||||
|
|
||||||
|
if !verifyTotp(totp, req.FormValue("code")) {
|
||||||
|
http.Error(rw, "400 Bad Request: invalid OTP code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
return tx.SetOtp(req.Context(), database.SetOtpParams{
|
||||||
|
Subject: auth.Subject,
|
||||||
|
OtpSecret: secret,
|
||||||
|
OtpDigits: int64(digits),
|
||||||
|
})
|
||||||
|
}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(rw, req, "/", http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyTotp(totp *gotp.TOTP, code string) bool {
|
||||||
|
t := time.Now()
|
||||||
|
if totp.VerifyTime(code, t) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if totp.VerifyTime(code, t.Add(-30*time.Second)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return totp.VerifyTime(code, t.Add(30*time.Second))
|
||||||
|
}
|
@ -7,6 +7,5 @@ import (
|
|||||||
|
|
||||||
func TestHasRole(t *testing.T) {
|
func TestHasRole(t *testing.T) {
|
||||||
assert.True(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin"))
|
assert.True(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin"))
|
||||||
assert.False(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin"))
|
|
||||||
assert.False(t, HasRole([]string{"lavender:", "test:something-else"}, "lavender:admin"))
|
assert.False(t, HasRole([]string{"lavender:", "test:something-else"}, "lavender:admin"))
|
||||||
}
|
}
|
||||||
|
@ -3,17 +3,14 @@ package server
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/1f349/cache"
|
"github.com/1f349/cache"
|
||||||
clientStore "github.com/1f349/lavender/client-store"
|
|
||||||
"github.com/1f349/lavender/conf"
|
"github.com/1f349/lavender/conf"
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
|
"github.com/1f349/lavender/logger"
|
||||||
"github.com/1f349/lavender/pages"
|
"github.com/1f349/lavender/pages"
|
||||||
scope2 "github.com/1f349/lavender/scope"
|
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/go-oauth2/oauth2/v4/generates"
|
|
||||||
"github.com/go-oauth2/oauth2/v4/manage"
|
"github.com/go-oauth2/oauth2/v4/manage"
|
||||||
"github.com/go-oauth2/oauth2/v4/server"
|
"github.com/go-oauth2/oauth2/v4/server"
|
||||||
"github.com/go-oauth2/oauth2/v4/store"
|
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -76,44 +73,15 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, s
|
|||||||
mailLinkCache: cache.New[mailLinkKey, string](),
|
mailLinkCache: cache.New[mailLinkKey, string](),
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthManager := manage.NewManager()
|
var err error
|
||||||
oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
|
hs.manager, err = issuer.NewManager(config.SsoServices)
|
||||||
oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
|
||||||
oauthManager.MustTokenStorage(store.NewMemoryTokenStore())
|
|
||||||
oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(signingKey, db))
|
|
||||||
oauthManager.MapClientStorage(clientStore.New(db))
|
|
||||||
|
|
||||||
oauthSrv := server.NewDefaultServer(oauthManager)
|
|
||||||
oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) {
|
|
||||||
cId, cSecret, err := server.ClientBasicHandler(req)
|
|
||||||
if cId == "" && cSecret == "" {
|
|
||||||
cId, cSecret, err = server.ClientFormHandler(req)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
logger.Logger.Fatal("Failed to load SSO services", "err", err)
|
||||||
}
|
}
|
||||||
return cId, cSecret, nil
|
|
||||||
})
|
|
||||||
oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization)
|
|
||||||
oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (scope string, err error) {
|
|
||||||
var form url.Values
|
|
||||||
if req.Method == http.MethodPost {
|
|
||||||
form = req.PostForm
|
|
||||||
} else {
|
|
||||||
form = req.URL.Query()
|
|
||||||
}
|
|
||||||
a := form.Get("scope")
|
|
||||||
if !scope2.ScopesExist(a) {
|
|
||||||
return "", errInvalidScope
|
|
||||||
}
|
|
||||||
return a, nil
|
|
||||||
})
|
|
||||||
addIdTokenSupport(oauthSrv, db, signingKey)
|
|
||||||
|
|
||||||
ssoManager := issuer.NewManager(config.SsoServices)
|
|
||||||
|
|
||||||
SetupOpenId(r, config.BaseUrl, signingKey)
|
SetupOpenId(r, config.BaseUrl, signingKey)
|
||||||
r.POST("/logout", hs.RequireAuthentication(fu))
|
r.GET("/", hs.OptionalAuthentication(false, hs.Home))
|
||||||
|
r.POST("/logout", hs.RequireAuthentication(hs.logoutPost))
|
||||||
|
|
||||||
// theme styles
|
// theme styles
|
||||||
r.GET("/assets/*filepath", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
r.GET("/assets/*filepath", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
@ -126,8 +94,16 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, s
|
|||||||
http.ServeContent(rw, req, path.Base(name), contentCache, out)
|
http.ServeContent(rw, req, path.Base(name), contentCache, out)
|
||||||
})
|
})
|
||||||
|
|
||||||
SetupManageApps(r)
|
// login steps
|
||||||
SetupManageUsers(r)
|
r.GET("/login", hs.OptionalAuthentication(false, hs.loginGet))
|
||||||
|
r.POST("/login", hs.OptionalAuthentication(false, hs.loginPost))
|
||||||
|
r.GET("/login/otp", hs.OptionalAuthentication(true, hs.loginOtpGet))
|
||||||
|
r.POST("/login/otp", hs.OptionalAuthentication(true, hs.loginOtpPost))
|
||||||
|
r.GET("/callback", hs.OptionalAuthentication(false, hs.loginCallback))
|
||||||
|
|
||||||
|
SetupManageApps(r, hs)
|
||||||
|
SetupManageUsers(r, hs)
|
||||||
|
SetupOAuth2(r, hs, signingKey, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) {
|
func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
@ -21,3 +21,11 @@ sql:
|
|||||||
go_type: "github.com/1f349/lavender/database/types.UserZone"
|
go_type: "github.com/1f349/lavender/database/types.UserZone"
|
||||||
- column: "users.locale"
|
- column: "users.locale"
|
||||||
go_type: "github.com/1f349/lavender/database/types.UserLocale"
|
go_type: "github.com/1f349/lavender/database/types.UserLocale"
|
||||||
|
- column: "users.auth_type"
|
||||||
|
go_type: "github.com/1f349/lavender/database/types.AuthType"
|
||||||
|
- column: "users.access_token"
|
||||||
|
go_type: "database/sql.NullString"
|
||||||
|
- column: "users.refresh_token"
|
||||||
|
go_type: "database/sql.NullString"
|
||||||
|
- column: "users.token_expiry"
|
||||||
|
go_type: "database/sql.NullTime"
|
||||||
|
28
utils/age.go
Normal file
28
utils/age.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ageTimeNow = time.Now
|
||||||
|
|
||||||
|
func Age(t time.Time) int {
|
||||||
|
n := ageTimeNow()
|
||||||
|
|
||||||
|
// the birthday is in the future so the age is 0
|
||||||
|
if n.Before(t) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// the year difference
|
||||||
|
dy := n.Year() - t.Year()
|
||||||
|
|
||||||
|
// the birthday in the current year
|
||||||
|
tCurrent := t.AddDate(dy, 0, 0)
|
||||||
|
|
||||||
|
// minus 1 if the birthday has not yet occurred in the current year
|
||||||
|
if tCurrent.Before(n) {
|
||||||
|
dy -= 1
|
||||||
|
}
|
||||||
|
return dy
|
||||||
|
}
|
30
utils/age_test.go
Normal file
30
utils/age_test.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAge(t *testing.T) {
|
||||||
|
lGmt := time.FixedZone("GMT", 0)
|
||||||
|
lBst := time.FixedZone("BST", 60*60)
|
||||||
|
|
||||||
|
tPast := time.Date(1939, time.January, 5, 0, 0, 0, 0, lGmt)
|
||||||
|
tPastDst := time.Date(2001, time.January, 5, 1, 0, 0, 0, lBst)
|
||||||
|
tCur := time.Date(2005, time.January, 5, 0, 30, 0, 0, lGmt)
|
||||||
|
tCurDst := time.Date(2005, time.January, 5, 0, 30, 0, 0, lBst)
|
||||||
|
tFut := time.Date(2008, time.January, 5, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
ageTimeNow = func() time.Time { return tCur }
|
||||||
|
assert.Equal(t, 65, Age(tPast))
|
||||||
|
assert.Equal(t, 3, Age(tPastDst))
|
||||||
|
assert.Equal(t, 0, Age(tFut))
|
||||||
|
|
||||||
|
ageTimeNow = func() time.Time { return tCurDst }
|
||||||
|
assert.Equal(t, 66, Age(tPast))
|
||||||
|
assert.Equal(t, 4, Age(tPastDst))
|
||||||
|
fmt.Println(tPastDst.AddDate(4, 0, 0).UTC(), tCur.UTC())
|
||||||
|
assert.Equal(t, 0, Age(tFut))
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user