diff --git a/auth/userinfofields.go b/auth/userinfofields.go index 7f2093c..3411eef 100644 --- a/auth/userinfofields.go +++ b/auth/userinfofields.go @@ -1,5 +1,7 @@ package auth +import "github.com/hardfinhq/go-date" + type UserInfoFields map[string]any func (u UserInfoFields) GetString(key string) (string, bool) { @@ -20,7 +22,24 @@ func (u UserInfoFields) GetStringOrEmpty(key string) string { return s } +func (u UserInfoFields) GetStringFromKeysOrEmpty(keys ...string) string { + for _, key := range keys { + s, _ := u[key].(string) + if s == "" { + continue + } + return s + } + return "" +} + func (u UserInfoFields) GetBoolean(key string) (bool, bool) { b, ok := u[key].(bool) return b, ok } + +func (u UserInfoFields) GetNullDate(key string) date.NullDate { + s, _ := u[key].(string) + fromStr, err := date.FromString(s) + return date.NullDate{Date: fromStr, Valid: err == nil} +} diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index ebdb20a..f32367f 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -3,10 +3,13 @@ package main import ( "context" "flag" + "fmt" "github.com/1f349/lavender" "github.com/1f349/lavender/conf" + "github.com/1f349/lavender/database" "github.com/1f349/lavender/logger" "github.com/1f349/lavender/pages" + "github.com/1f349/lavender/role" "github.com/1f349/lavender/server" "github.com/1f349/mjwt" "github.com/charmbracelet/log" @@ -114,6 +117,10 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) logger.Logger.Fatal("Failed to open database", "err", err) } + if err := checkDbHasUser(db); err != nil { + logger.Logger.Fatal("Failed to add initial user", "err", err) + } + if err := pages.LoadPages(wd); err != nil { logger.Logger.Fatal("Failed to load page templates:", err) } @@ -168,3 +175,45 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) return subcommands.ExitSuccess } + +func checkDbHasUser(db *database.Queries) error { + value, err := db.HasUser(context.Background()) + if err != nil { + return err + } + + if !value { + logger.Logger.Warn("No users are available, setting up initial admin user") + + ctx := context.Background() + err = db.UseTx(ctx, func(tx *database.Queries) error { + adminUuid, err := db.AddLocalUser(context.Background(), database.AddLocalUserParams{ + Password: "admin", + Email: "admin@localhost", + EmailVerified: false, + Name: "Admin", + Username: "admin", + ChangePassword: true, + }) + if err != nil { + return fmt.Errorf("failed to add user: %w", err) + } + roleId, err := db.AddRole(context.Background(), role.LavenderAdmin) + if err != nil { + return fmt.Errorf("failed to add role: %w", err) + } + err = db.AddUserRole(context.Background(), database.AddUserRoleParams{ + RoleID: roleId, + Subject: adminUuid, + }) + if err != nil { + return fmt.Errorf("failed to add user role: %w", err) + } + return nil + }) + if err != nil { + return err + } + } + return nil +} diff --git a/conf/conf.go b/conf/conf.go index fd8f314..49c09eb 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -6,12 +6,13 @@ import ( ) type Conf struct { - Listen string `yaml:"listen"` - BaseUrl string `yaml:"baseUrl"` - ServiceName string `yaml:"serviceName"` - Issuer string `yaml:"issuer"` - Kid string `yaml:"kid"` - Namespace string `yaml:"namespace"` - Mail mail.Mail `yaml:"mail"` - SsoServices []issuer.SsoConfig `yaml:"ssoServices"` + Listen string `yaml:"listen"` + BaseUrl string `yaml:"baseUrl"` + ServiceName string `yaml:"serviceName"` + Issuer string `yaml:"issuer"` + Kid string `yaml:"kid"` + Namespace string `yaml:"namespace"` + OtpIssuer string `yaml:"otpIssuer"` + Mail mail.Mail `yaml:"mail"` + SsoServices map[string]issuer.SsoConfig `yaml:"ssoServices"` } diff --git a/database/manage-oauth.sql.go b/database/manage-oauth.sql.go index 02d4a56..7a7e4dc 100644 --- a/database/manage-oauth.sql.go +++ b/database/manage-oauth.sql.go @@ -20,14 +20,14 @@ SELECT subject, active FROM client_store WHERE owner_subject = ? - OR ? = 1 + OR CAST(? AS BOOLEAN) = 1 LIMIT 25 OFFSET ? ` type GetAppListParams struct { - OwnerSubject string `json:"owner_subject"` - Column2 interface{} `json:"column_2"` - Offset int64 `json:"offset"` + OwnerSubject string `json:"owner_subject"` + Column2 bool `json:"column_2"` + Offset int64 `json:"offset"` } type GetAppListRow struct { diff --git a/database/manage-users.sql.go b/database/manage-users.sql.go index 8b0a99a..aa05314 100644 --- a/database/manage-users.sql.go +++ b/database/manage-users.sql.go @@ -7,10 +7,30 @@ package database import ( "context" + "database/sql" "strings" "time" + + "github.com/1f349/lavender/database/types" ) +const addUserRole = `-- name: AddUserRole :exec +INSERT INTO users_roles(role_id, user_id) +SELECT ?, users.id +FROM users +WHERE subject = ? +` + +type AddUserRoleParams struct { + RoleID int64 `json:"role_id"` + Subject string `json:"subject"` +} + +func (q *Queries) AddUserRole(ctx context.Context, arg AddUserRoleParams) error { + _, err := q.db.ExecContext(ctx, addUserRole, arg.RoleID, arg.Subject) + return err +} + const changeUserActive = `-- name: ChangeUserActive :exec UPDATE users SET active = cast(? as boolean) @@ -34,26 +54,24 @@ SELECT users.subject, website, email, email_verified, - users.updated_at as user_updated_at, - p.updated_at as profile_updated_at, + updated_at, active FROM users - INNER JOIN main.profiles p on users.subject = p.subject LIMIT 50 OFFSET ? ` type GetUserListRow struct { - Subject string `json:"subject"` - Name string `json:"name"` - Picture string `json:"picture"` - Website string `json:"website"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UserUpdatedAt time.Time `json:"user_updated_at"` - ProfileUpdatedAt time.Time `json:"profile_updated_at"` - Active bool `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` } +// INNER JOIN main.profiles p on users.subject = p.subject func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) { rows, err := q.db.QueryContext(ctx, getUserList, offset) if err != nil { @@ -70,8 +88,7 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR &i.Website, &i.Email, &i.EmailVerified, - &i.UserUpdatedAt, - &i.ProfileUpdatedAt, + &i.UpdatedAt, &i.Active, ); err != nil { return nil, err @@ -87,6 +104,25 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR return items, nil } +const getUserToken = `-- name: GetUserToken :one +SELECT access_token, refresh_token, token_expiry +FROM users +WHERE subject = ? +` + +type GetUserTokenRow struct { + AccessToken sql.NullString `json:"access_token"` + RefreshToken sql.NullString `json:"refresh_token"` + TokenExpiry sql.NullTime `json:"token_expiry"` +} + +func (q *Queries) GetUserToken(ctx context.Context, subject string) (GetUserTokenRow, error) { + row := q.db.QueryRowContext(ctx, getUserToken, subject) + var i GetUserTokenRow + err := row.Scan(&i.AccessToken, &i.RefreshToken, &i.TokenExpiry) + return i, err +} + const getUsersRoles = `-- name: GetUsersRoles :many SELECT r.role, u.id FROM users_roles @@ -133,6 +169,105 @@ func (q *Queries) GetUsersRoles(ctx context.Context, userIds []int64) ([]GetUser return items, nil } +const modifyUserAuth = `-- name: ModifyUserAuth :exec +UPDATE users +SET auth_type = ?, + auth_namespace=?, + auth_user = ? +WHERE subject = ? +` + +type ModifyUserAuthParams struct { + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` + Subject string `json:"subject"` +} + +func (q *Queries) ModifyUserAuth(ctx context.Context, arg ModifyUserAuthParams) error { + _, err := q.db.ExecContext(ctx, modifyUserAuth, + arg.AuthType, + arg.AuthNamespace, + arg.AuthUser, + arg.Subject, + ) + return err +} + +const modifyUserEmail = `-- name: ModifyUserEmail :exec +UPDATE users +SET email = ?, + email_verified=? +WHERE subject = ? +` + +type ModifyUserEmailParams struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Subject string `json:"subject"` +} + +func (q *Queries) ModifyUserEmail(ctx context.Context, arg ModifyUserEmailParams) error { + _, err := q.db.ExecContext(ctx, modifyUserEmail, arg.Email, arg.EmailVerified, arg.Subject) + return err +} + +const modifyUserRemoteLogin = `-- name: ModifyUserRemoteLogin :exec +UPDATE users +SET login = ?, + profile_url = ? +WHERE subject = ? +` + +type ModifyUserRemoteLoginParams struct { + Login string `json:"login"` + ProfileUrl string `json:"profile_url"` + Subject string `json:"subject"` +} + +func (q *Queries) ModifyUserRemoteLogin(ctx context.Context, arg ModifyUserRemoteLoginParams) error { + _, err := q.db.ExecContext(ctx, modifyUserRemoteLogin, arg.Login, arg.ProfileUrl, arg.Subject) + return err +} + +const removeUserRoles = `-- name: RemoveUserRoles :exec +DELETE +FROM users_roles +WHERE user_id IN (SELECT id + FROM users + WHERE subject = ?) +` + +func (q *Queries) RemoveUserRoles(ctx context.Context, subject string) error { + _, err := q.db.ExecContext(ctx, removeUserRoles, subject) + return err +} + +const updateUserToken = `-- name: UpdateUserToken :exec +UPDATE users +SET access_token = ?, + refresh_token=?, + token_expiry = ? +WHERE subject = ? +` + +type UpdateUserTokenParams struct { + AccessToken sql.NullString `json:"access_token"` + RefreshToken sql.NullString `json:"refresh_token"` + TokenExpiry sql.NullTime `json:"token_expiry"` + Subject string `json:"subject"` +} + +func (q *Queries) UpdateUserToken(ctx context.Context, arg UpdateUserTokenParams) error { + _, err := q.db.ExecContext(ctx, updateUserToken, + arg.AccessToken, + arg.RefreshToken, + arg.TokenExpiry, + arg.Subject, + ) + return err +} + const userEmailExists = `-- name: UserEmailExists :one SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists ` @@ -143,3 +278,14 @@ func (q *Queries) UserEmailExists(ctx context.Context, email string) (bool, erro err := row.Scan(&email_exists) return email_exists, err } + +const verifyUserEmail = `-- name: VerifyUserEmail :exec +UPDATE users +SET email_verified=1 +WHERE subject = ? +` + +func (q *Queries) VerifyUserEmail(ctx context.Context, subject string) error { + _, err := q.db.ExecContext(ctx, verifyUserEmail, subject) + return err +} diff --git a/database/migrations/20240820202502_init.up.sql b/database/migrations/20240820202502_init.up.sql index d78164e..3ba0bf3 100644 --- a/database/migrations/20240820202502_init.up.sql +++ b/database/migrations/20240820202502_init.up.sql @@ -21,9 +21,21 @@ CREATE TABLE users zone TEXT NOT NULL DEFAULT 'UTC', locale TEXT NOT NULL DEFAULT 'en-US', + login TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + auth_type INTEGER NOT NULL, auth_namespace TEXT NOT NULL, - auth_user TEXT NOT NULL + auth_user TEXT NOT NULL, + + access_token TEXT NULL DEFAULT NULL, + refresh_token TEXT NULL DEFAULT NULL, + token_expiry DATETIME NULL DEFAULT NULL, + + otp_secret TEXT NOT NULL DEFAULT '', + otp_digits INTEGER NOT NULL DEFAULT 0, + + to_delete BOOLEAN NOT NULL DEFAULT 0 ); CREATE INDEX users_subject ON users (subject); @@ -39,21 +51,12 @@ CREATE TABLE users_roles role_id INTEGER NOT NULL, user_id INTEGER NOT NULL, - FOREIGN KEY (role_id) REFERENCES roles (id), + FOREIGN KEY (role_id) REFERENCES roles (id) ON DELETE RESTRICT, FOREIGN KEY (user_id) REFERENCES users (id), CONSTRAINT user_role UNIQUE (role_id, user_id) ); -CREATE TABLE otp -( - subject INTEGER NOT NULL UNIQUE PRIMARY KEY, - secret TEXT NOT NULL, - digits INTEGER NOT NULL, - - FOREIGN KEY (subject) REFERENCES users (subject) -); - CREATE TABLE client_store ( subject TEXT NOT NULL UNIQUE PRIMARY KEY, diff --git a/database/models.go b/database/models.go index fc0b48f..ce08f56 100644 --- a/database/models.go +++ b/database/models.go @@ -5,9 +5,12 @@ package database import ( + "database/sql" "time" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" + "github.com/hardfinhq/go-date" ) type ClientStore struct { @@ -22,38 +25,39 @@ type ClientStore struct { Active bool `json:"active"` } -type Otp struct { - Subject int64 `json:"subject"` - Secret string `json:"secret"` - Digits int64 `json:"digits"` -} - -type Profile struct { - Subject string `json:"subject"` - Name string `json:"name"` - Picture string `json:"picture"` - Website string `json:"website"` - Pronouns string `json:"pronouns"` - Birthdate interface{} `json:"birthdate"` - Zone string `json:"zone"` - Locale string `json:"locale"` - UpdatedAt time.Time `json:"updated_at"` -} - type Role struct { ID int64 `json:"id"` Role string `json:"role"` } type User struct { - ID int64 `json:"id"` - Subject string `json:"subject"` - Password password.HashString `json:"password"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UpdatedAt time.Time `json:"updated_at"` - Registered time.Time `json:"registered"` - Active bool `json:"active"` + ID int64 `json:"id"` + Subject string `json:"subject"` + Password password.HashString `json:"password"` + ChangePassword bool `json:"change_password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Registered time.Time `json:"registered"` + Active bool `json:"active"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate date.NullDate `json:"birthdate"` + Zone string `json:"zone"` + Locale types.UserLocale `json:"locale"` + Login string `json:"login"` + ProfileUrl string `json:"profile_url"` + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` + AccessToken sql.NullString `json:"access_token"` + RefreshToken sql.NullString `json:"refresh_token"` + TokenExpiry sql.NullTime `json:"token_expiry"` + OtpSecret string `json:"otp_secret"` + OtpDigits int64 `json:"otp_digits"` + ToDelete bool `json:"to_delete"` } type UsersRole struct { diff --git a/database/otp.sql.go b/database/otp.sql.go index fc30726..04197cb 100644 --- a/database/otp.sql.go +++ b/database/otp.sql.go @@ -10,31 +10,32 @@ import ( ) const deleteOtp = `-- name: DeleteOtp :exec -DELETE -FROM otp -WHERE otp.subject = ? +UPDATE users +SET otp_secret='', + otp_digits=0 +WHERE subject = ? ` -func (q *Queries) DeleteOtp(ctx context.Context, subject int64) error { +func (q *Queries) DeleteOtp(ctx context.Context, subject string) error { _, err := q.db.ExecContext(ctx, deleteOtp, subject) return err } const getOtp = `-- name: GetOtp :one -SELECT secret, digits -FROM otp +SELECT otp_secret, otp_digits +FROM users WHERE subject = ? ` type GetOtpRow struct { - Secret string `json:"secret"` - Digits int64 `json:"digits"` + OtpSecret string `json:"otp_secret"` + OtpDigits int64 `json:"otp_digits"` } -func (q *Queries) GetOtp(ctx context.Context, subject int64) (GetOtpRow, error) { +func (q *Queries) GetOtp(ctx context.Context, subject string) (GetOtpRow, error) { row := q.db.QueryRowContext(ctx, getOtp, subject) var i GetOtpRow - err := row.Scan(&i.Secret, &i.Digits) + err := row.Scan(&i.OtpSecret, &i.OtpDigits) return i, err } @@ -52,10 +53,13 @@ func (q *Queries) GetUserEmail(ctx context.Context, subject string) (string, err } const hasOtp = `-- name: HasOtp :one -SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp +SELECT CAST(1 AS BOOLEAN) AS hasOtp +FROM users +WHERE subject = ? + AND otp_secret != '' ` -func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) { +func (q *Queries) HasOtp(ctx context.Context, subject string) (bool, error) { row := q.db.QueryRowContext(ctx, hasOtp, subject) var hasotp bool err := row.Scan(&hasotp) @@ -63,19 +67,19 @@ func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) { } const setOtp = `-- name: SetOtp :exec -INSERT OR -REPLACE -INTO otp (subject, secret, digits) -VALUES (?, ?, ?) +UPDATE users +SET otp_secret = ?, + otp_digits=? +WHERE subject = ? ` type SetOtpParams struct { - Subject int64 `json:"subject"` - Secret string `json:"secret"` - Digits int64 `json:"digits"` + OtpSecret string `json:"otp_secret"` + OtpDigits int64 `json:"otp_digits"` + Subject string `json:"subject"` } func (q *Queries) SetOtp(ctx context.Context, arg SetOtpParams) error { - _, err := q.db.ExecContext(ctx, setOtp, arg.Subject, arg.Secret, arg.Digits) + _, err := q.db.ExecContext(ctx, setOtp, arg.OtpSecret, arg.OtpDigits, arg.Subject) return err } diff --git a/database/password-wrapper.go b/database/password-wrapper.go index 07f94ee..f30b251 100644 --- a/database/password-wrapper.go +++ b/database/password-wrapper.go @@ -2,35 +2,69 @@ package database import ( "context" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" "github.com/google/uuid" "time" ) -type AddUserParams struct { - Name string `json:"name"` - Subject string `json:"subject"` - Password string `json:"password"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UpdatedAt time.Time `json:"updated_at"` - Active bool `json:"active"` +type AddLocalUserParams struct { + Password string `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Username string `json:"username"` + ChangePassword bool `json:"change_password"` } -func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) { +func (q *Queries) AddLocalUser(ctx context.Context, arg AddLocalUserParams) (string, error) { pwHash, err := password.HashPassword(arg.Password) if err != nil { return "", err } n := time.Now() a := addUserParams{ - Subject: uuid.NewString(), - Password: pwHash, - Email: arg.Email, - EmailVerified: arg.EmailVerified, - UpdatedAt: n, - Registered: n, - Active: true, + Subject: uuid.NewString(), + Password: pwHash, + Email: arg.Email, + EmailVerified: arg.EmailVerified, + UpdatedAt: n, + Registered: n, + Active: true, + Name: arg.Name, + Login: arg.Username, + ChangePassword: arg.ChangePassword, + AuthType: types.AuthTypeLocal, + AuthNamespace: "", + AuthUser: arg.Username, + } + return a.Subject, q.addUser(ctx, a) +} + +type AddOAuthUserParams struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Username string `json:"username"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` +} + +func (q *Queries) AddOAuthUser(ctx context.Context, arg AddOAuthUserParams) (string, error) { + n := time.Now() + a := addUserParams{ + Subject: uuid.NewString(), + Email: arg.Email, + EmailVerified: arg.EmailVerified, + UpdatedAt: n, + Registered: n, + Active: true, + Name: arg.Name, + Login: arg.Username, + ChangePassword: false, + AuthType: types.AuthTypeOauth2, + AuthNamespace: arg.AuthNamespace, + AuthUser: arg.AuthUser, } return a.Subject, q.addUser(ctx, a) } diff --git a/database/profiles.sql.go b/database/profiles.sql.go index fd54f5c..d92f220 100644 --- a/database/profiles.sql.go +++ b/database/profiles.sql.go @@ -8,17 +8,38 @@ package database import ( "context" "time" + + "github.com/1f349/lavender/database/types" + "github.com/hardfinhq/go-date" ) const getProfile = `-- name: GetProfile :one -SELECT profiles.subject, profiles.name, profiles.picture, profiles.website, profiles.pronouns, profiles.birthdate, profiles.zone, profiles.locale, profiles.updated_at -FROM profiles +SELECT subject, + name, + picture, + website, + pronouns, + birthdate, + zone, + locale +FROM users WHERE subject = ? ` -func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, error) { +type GetProfileRow struct { + Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate date.NullDate `json:"birthdate"` + Zone string `json:"zone"` + Locale types.UserLocale `json:"locale"` +} + +func (q *Queries) GetProfile(ctx context.Context, subject string) (GetProfileRow, error) { row := q.db.QueryRowContext(ctx, getProfile, subject) - var i Profile + var i GetProfileRow err := row.Scan( &i.Subject, &i.Name, @@ -28,13 +49,12 @@ func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, erro &i.Birthdate, &i.Zone, &i.Locale, - &i.UpdatedAt, ) return i, err } const modifyProfile = `-- name: ModifyProfile :exec -UPDATE profiles +UPDATE users SET name = ?, picture = ?, website = ?, @@ -47,15 +67,15 @@ WHERE subject = ? ` type ModifyProfileParams struct { - Name string `json:"name"` - Picture string `json:"picture"` - Website string `json:"website"` - Pronouns string `json:"pronouns"` - Birthdate interface{} `json:"birthdate"` - Zone string `json:"zone"` - Locale string `json:"locale"` - UpdatedAt time.Time `json:"updated_at"` - Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate date.NullDate `json:"birthdate"` + Zone string `json:"zone"` + Locale types.UserLocale `json:"locale"` + UpdatedAt time.Time `json:"updated_at"` + Subject string `json:"subject"` } func (q *Queries) ModifyProfile(ctx context.Context, arg ModifyProfileParams) error { diff --git a/database/queries/manage-oauth.sql b/database/queries/manage-oauth.sql index 7225f40..12cb631 100644 --- a/database/queries/manage-oauth.sql +++ b/database/queries/manage-oauth.sql @@ -15,7 +15,7 @@ SELECT subject, active FROM client_store WHERE owner_subject = ? - OR ? = 1 + OR CAST(? AS BOOLEAN) = 1 LIMIT 25 OFFSET ?; -- name: InsertClientApp :exec diff --git a/database/queries/manage-users.sql b/database/queries/manage-users.sql index 587b87e..b9814bb 100644 --- a/database/queries/manage-users.sql +++ b/database/queries/manage-users.sql @@ -5,11 +5,10 @@ SELECT users.subject, website, email, email_verified, - users.updated_at as user_updated_at, - p.updated_at as profile_updated_at, + updated_at, active FROM users - INNER JOIN main.profiles p on users.subject = p.subject +--INNER JOIN main.profiles p on users.subject = p.subject LIMIT 50 OFFSET ?; -- name: GetUsersRoles :many @@ -24,5 +23,54 @@ UPDATE users SET active = cast(? as boolean) WHERE subject = ?; +-- name: VerifyUserEmail :exec +UPDATE users +SET email_verified=1 +WHERE subject = ?; + -- name: UserEmailExists :one SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists; + +-- name: ModifyUserEmail :exec +UPDATE users +SET email = ?, + email_verified=? +WHERE subject = ?; + +-- name: ModifyUserAuth :exec +UPDATE users +SET auth_type = ?, + auth_namespace=?, + auth_user = ? +WHERE subject = ?; + +-- name: ModifyUserRemoteLogin :exec +UPDATE users +SET login = ?, + profile_url = ? +WHERE subject = ?; + +-- name: UpdateUserToken :exec +UPDATE users +SET access_token = ?, + refresh_token=?, + token_expiry = ? +WHERE subject = ?; + +-- name: GetUserToken :one +SELECT access_token, refresh_token, token_expiry +FROM users +WHERE subject = ?; + +-- name: RemoveUserRoles :exec +DELETE +FROM users_roles +WHERE user_id IN (SELECT id + FROM users + WHERE subject = ?); + +-- name: AddUserRole :exec +INSERT INTO users_roles(role_id, user_id) +SELECT ?, users.id +FROM users +WHERE subject = ?; diff --git a/database/queries/otp.sql b/database/queries/otp.sql index 175399a..94599c5 100644 --- a/database/queries/otp.sql +++ b/database/queries/otp.sql @@ -1,21 +1,25 @@ -- name: SetOtp :exec -INSERT OR -REPLACE -INTO otp (subject, secret, digits) -VALUES (?, ?, ?); +UPDATE users +SET otp_secret = ?, + otp_digits=? +WHERE subject = ?; -- name: DeleteOtp :exec -DELETE -FROM otp -WHERE otp.subject = ?; +UPDATE users +SET otp_secret='', + otp_digits=0 +WHERE subject = ?; -- name: GetOtp :one -SELECT secret, digits -FROM otp +SELECT otp_secret, otp_digits +FROM users WHERE subject = ?; -- name: HasOtp :one -SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp; +SELECT CAST(1 AS BOOLEAN) AS hasOtp +FROM users +WHERE subject = ? + AND otp_secret != ''; -- name: GetUserEmail :one SELECT email diff --git a/database/queries/profiles.sql b/database/queries/profiles.sql index 134da89..74203c6 100644 --- a/database/queries/profiles.sql +++ b/database/queries/profiles.sql @@ -1,10 +1,17 @@ -- name: GetProfile :one -SELECT profiles.* -FROM profiles +SELECT subject, + name, + picture, + website, + pronouns, + birthdate, + zone, + locale +FROM users WHERE subject = ?; -- name: ModifyProfile :exec -UPDATE profiles +UPDATE users SET name = ?, picture = ?, website = ?, diff --git a/database/queries/roles.sql b/database/queries/roles.sql new file mode 100644 index 0000000..3c7d54e --- /dev/null +++ b/database/queries/roles.sql @@ -0,0 +1,8 @@ +-- name: AddRole :execlastid +INSERT OR IGNORE INTO roles(role) +VALUES (?); + +-- name: RemoveRole :exec +DELETE +FROM roles +WHERE role = ?; diff --git a/database/queries/users.sql b/database/queries/users.sql index 1a916aa..0d33bfc 100644 --- a/database/queries/users.sql +++ b/database/queries/users.sql @@ -3,15 +3,11 @@ SELECT count(subject) > 0 AS hasUser FROM users; -- name: addUser :exec -INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) -VALUES (?, ?, ?, ?, ?, ?, ?); - --- name: addOAuthUser :exec -INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) -VALUES (?, ?, ?, ?, ?, ?, ?); +INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); -- name: checkLogin :one -SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified +SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified FROM users WHERE users.subject = ? LIMIT 1; @@ -48,3 +44,9 @@ SET password = ?, updated_at=? WHERE subject = ? AND password = ?; + +-- name: FlagUserAsDeleted :exec +UPDATE users +SET active= false, + to_delete = true +WHERE subject = ?; diff --git a/database/roles.sql.go b/database/roles.sql.go new file mode 100644 index 0000000..fa557bd --- /dev/null +++ b/database/roles.sql.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: roles.sql + +package database + +import ( + "context" +) + +const addRole = `-- name: AddRole :execlastid +INSERT OR IGNORE INTO roles(role) +VALUES (?) +` + +func (q *Queries) AddRole(ctx context.Context, role string) (int64, error) { + result, err := q.db.ExecContext(ctx, addRole, role) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +const removeRole = `-- name: RemoveRole :exec +DELETE +FROM roles +WHERE role = ? +` + +func (q *Queries) RemoveRole(ctx context.Context, role string) error { + _, err := q.db.ExecContext(ctx, removeRole, role) + return err +} diff --git a/database/tx.go b/database/tx.go new file mode 100644 index 0000000..15f8522 --- /dev/null +++ b/database/tx.go @@ -0,0 +1,26 @@ +package database + +import ( + "context" + "database/sql" + "errors" +) + +var errCannotOpenTransactionWithoutSqlDB = errors.New("cannot open transaction without sql.DB") + +func (q *Queries) UseTx(ctx context.Context, cb func(tx *Queries) error) error { + sqlDB, ok := q.db.(*sql.DB) + if !ok { + panic(errCannotOpenTransactionWithoutSqlDB) + } + tx, err := sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + err = cb(q.WithTx(tx)) + if err != nil { + return err + } + return tx.Commit() +} diff --git a/database/types/authtype.go b/database/types/authtype.go index 903f717..09ee32b 100644 --- a/database/types/authtype.go +++ b/database/types/authtype.go @@ -3,7 +3,7 @@ package types type AuthType byte const ( - AuthTypeBase AuthType = iota + AuthTypeLocal AuthType = iota AuthTypeOauth2 ) diff --git a/database/users.sql.go b/database/users.sql.go index 04c1c02..7b7b6ef 100644 --- a/database/users.sql.go +++ b/database/users.sql.go @@ -9,11 +9,24 @@ import ( "context" "time" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" ) +const flagUserAsDeleted = `-- name: FlagUserAsDeleted :exec +UPDATE users +SET active= false, + to_delete = true +WHERE subject = ? +` + +func (q *Queries) FlagUserAsDeleted(ctx context.Context, subject string) error { + _, err := q.db.ExecContext(ctx, flagUserAsDeleted, subject) + return err +} + const getUser = `-- name: GetUser :one -SELECT id, subject, password, email, email_verified, updated_at, registered, active +SELECT id, subject, password, change_password, email, email_verified, updated_at, registered, active, name, picture, website, pronouns, birthdate, zone, locale, login, profile_url, auth_type, auth_namespace, auth_user, access_token, refresh_token, token_expiry, otp_secret, otp_digits, to_delete FROM users WHERE subject = ? LIMIT 1 @@ -26,11 +39,30 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) { &i.ID, &i.Subject, &i.Password, + &i.ChangePassword, &i.Email, &i.EmailVerified, &i.UpdatedAt, &i.Registered, &i.Active, + &i.Name, + &i.Picture, + &i.Website, + &i.Pronouns, + &i.Birthdate, + &i.Zone, + &i.Locale, + &i.Login, + &i.ProfileUrl, + &i.AuthType, + &i.AuthNamespace, + &i.AuthUser, + &i.AccessToken, + &i.RefreshToken, + &i.TokenExpiry, + &i.OtpSecret, + &i.OtpDigits, + &i.ToDelete, ) return i, err } @@ -98,18 +130,24 @@ func (q *Queries) UserHasRole(ctx context.Context, arg UserHasRoleParams) error } const addUser = `-- name: addUser :exec -INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) -VALUES (?, ?, ?, ?, ?, ?, ?) +INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` type addUserParams struct { - Subject string `json:"subject"` - Password password.HashString `json:"password"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UpdatedAt time.Time `json:"updated_at"` - Registered time.Time `json:"registered"` - Active bool `json:"active"` + Subject string `json:"subject"` + Password password.HashString `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Registered time.Time `json:"registered"` + Active bool `json:"active"` + Name string `json:"name"` + Login string `json:"login"` + ChangePassword bool `json:"change_password"` + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` } func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { @@ -121,6 +159,12 @@ func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { arg.UpdatedAt, arg.Registered, arg.Active, + arg.Name, + arg.Login, + arg.ChangePassword, + arg.AuthType, + arg.AuthNamespace, + arg.AuthUser, ) return err } @@ -151,7 +195,7 @@ func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPassword } const checkLogin = `-- name: checkLogin :one -SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified +SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified FROM users WHERE users.subject = ? LIMIT 1 diff --git a/go.mod b/go.mod index dc2b926..63fa064 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,10 @@ require ( github.com/1f349/cache v0.0.3 github.com/1f349/mjwt v0.4.1 github.com/1f349/overlapfs v0.0.1 - github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 + github.com/1f349/simplemail v0.0.5 github.com/charmbracelet/log v0.4.0 github.com/cloudflare/tableflip v1.2.3 github.com/emersion/go-message v0.18.1 - github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 - github.com/emersion/go-smtp v0.21.3 github.com/go-oauth2/oauth2/v4 v4.5.2 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/golang-migrate/migrate/v4 v4.17.1 @@ -21,10 +19,13 @@ require ( github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mrmelon54/pronouns v1.0.3 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/spf13/afero v1.11.0 github.com/stretchr/testify v1.9.0 + github.com/xlzd/gotp v0.1.0 golang.org/x/crypto v0.26.0 golang.org/x/oauth2 v0.22.0 + golang.org/x/sync v0.8.0 golang.org/x/text v0.17.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -36,6 +37,8 @@ require ( github.com/charmbracelet/lipgloss v0.12.1 // indirect github.com/charmbracelet/x/ansi v0.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 // indirect + github.com/emersion/go-smtp v0.21.3 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect @@ -63,6 +66,5 @@ require ( go.uber.org/atomic v1.11.0 // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect golang.org/x/net v0.28.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect ) diff --git a/go.sum b/go.sum index 98374f7..7717885 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/1f349/overlapfs v0.0.1 h1:LAxBolrXFAgU0yqZtXg/C/aaPq3eoQSPpBc49BHuTp0 github.com/1f349/overlapfs v0.0.1/go.mod h1:I6aItQycr7nrzplmfNXp/QF9tTmKRSgY3fXmu/7Ky2o= github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U= github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg= -github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 h1:jPg+0bgKD5kY7yQtRZqeba+BGKFE51evGvwewZwa7Xc= -github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63/go.mod h1:1zFQhcbgiyPSWHVMp0cXJjmd6FhasP5bf5tWS4ZK61A= +github.com/1f349/simplemail v0.0.5 h1:cr+8pdWhFE/+XVSO7ZTjntySbmIbTqmDy2SR9cHAPLE= +github.com/1f349/simplemail v0.0.5/go.mod h1:ppAIqkvVkI6L99EefbR5NgOjpePNK/RKgeoehj5A+kU= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= @@ -146,6 +146,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99 github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= @@ -199,6 +201,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/xlzd/gotp v0.1.0 h1:37blvlKCh38s+fkem+fFh7sMnceltoIEBYTVXyoa5Po= +github.com/xlzd/gotp v0.1.0/go.mod h1:ndLJ3JKzi3xLmUProq4LLxCuECL93dG9WASNLpHz8qg= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= diff --git a/issuer/manager.go b/issuer/manager.go index 87b74dc..8520c15 100644 --- a/issuer/manager.go +++ b/issuer/manager.go @@ -25,6 +25,7 @@ func NewManager(services map[string]SsoConfig) (*Manager, error) { } // save by namespace + conf.Namespace = namespace l.m[namespace] = conf } return l, nil diff --git a/mail/from-address.go b/mail/from-address.go deleted file mode 100644 index e52f5f8..0000000 --- a/mail/from-address.go +++ /dev/null @@ -1,26 +0,0 @@ -package mail - -import ( - "encoding/json" - "github.com/emersion/go-message/mail" -) - -type FromAddress struct { - *mail.Address -} - -var _ json.Unmarshaler = &FromAddress{} - -func (f *FromAddress) UnmarshalJSON(b []byte) error { - var a string - err := json.Unmarshal(b, &a) - if err != nil { - return err - } - address, err := mail.ParseAddress(a) - if err != nil { - return err - } - f.Address = address - return nil -} diff --git a/mail/mail.go b/mail/mail.go index 8403664..dca2e29 100644 --- a/mail/mail.go +++ b/mail/mail.go @@ -1,96 +1,48 @@ package mail import ( - "bytes" + "embed" + "errors" + "fmt" + "github.com/1f349/overlapfs" + "github.com/1f349/simplemail" "github.com/emersion/go-message/mail" - "github.com/emersion/go-sasl" - "github.com/emersion/go-smtp" - "io" - "net" - "time" + "io/fs" + "os" + "path/filepath" ) +//go:embed templates/*.go.html templates/*.go.txt +var embeddedTemplates embed.FS + type Mail struct { - Name string `json:"name"` - Tls bool `json:"tls"` - Server string `json:"server"` - From FromAddress `json:"from"` - Username string `json:"username"` - Password string `json:"password"` + mail *simplemail.SimpleMail + name string } -func (m *Mail) loginInfo() sasl.Client { - return sasl.NewPlainClient("", m.Username, m.Password) -} - -func (m *Mail) mailCall(to []string, r io.Reader) error { - host, _, err := net.SplitHostPort(m.Server) - if err != nil { - return err - } - if m.Tls { - return smtp.SendMailTLS(m.Server, m.loginInfo(), m.From.String(), to, r) - } - if host == "localhost" || host == "127.0.0.1" { - // internals of smtp.SendMail without STARTTLS for localhost testing - dial, err := smtp.Dial(m.Server) - if err != nil { - return err +func New(sender *simplemail.Mail, wd, name string) (*Mail, error) { + var o fs.FS = embeddedTemplates + o, _ = fs.Sub(o, "templates") + if wd != "" { + mailDir := filepath.Join(wd, "mail-templates") + err := os.Mkdir(mailDir, os.ModePerm) + if err == nil || errors.Is(err, os.ErrExist) { + wdFs := os.DirFS(mailDir) + o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs} } - err = dial.Auth(m.loginInfo()) - if err != nil { - return err - } - return dial.SendMail(m.From.String(), to, r) } - return smtp.SendMail(m.Server, m.loginInfo(), m.From.String(), to, r) + + simpleMail, err := simplemail.New(sender, o) + return &Mail{ + mail: simpleMail, + name: name, + }, err } -func (m *Mail) SendMail(subject string, to []*mail.Address, htmlBody, textBody io.Reader) error { - // generate the email in this template - buf := new(bytes.Buffer) - - // setup mail headers - var h mail.Header - h.SetDate(time.Now()) - h.SetSubject(subject) - h.SetAddressList("From", []*mail.Address{m.From.Address}) - h.SetAddressList("To", to) - h.Set("Content-Type", "multipart/alternative") - - // setup html and text alternative headers - var hHtml, hTxt mail.InlineHeader - hHtml.Set("Content-Type", "text/html; charset=utf-8") - hTxt.Set("Content-Type", "text/plain; charset=utf-8") - - createWriter, err := mail.CreateWriter(buf, h) - if err != nil { - return err - } - inline, err := createWriter.CreateInline() - if err != nil { - return err - } - partHtml, err := inline.CreatePart(hHtml) - if err != nil { - return err - } - if _, err := io.Copy(partHtml, htmlBody); err != nil { - return err - } - partTxt, err := inline.CreatePart(hTxt) - if err != nil { - return err - } - if _, err := io.Copy(partTxt, textBody); err != nil { - return err - } - - // convert all to addresses to strings - toStr := make([]string, len(to)) - for i := range toStr { - toStr[i] = to[i].String() - } - - return m.mailCall(toStr, buf) +func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error { + return m.mail.Send(templateName, fmt.Sprintf("%s - %s", subject, m.name), to, map[string]any{ + "ServiceName": m.name, + "Name": nameOfUser, + "Data": data, + }) } diff --git a/mail/send-template.go b/mail/send-template.go deleted file mode 100644 index 5f2c22f..0000000 --- a/mail/send-template.go +++ /dev/null @@ -1,18 +0,0 @@ -package mail - -import ( - "bytes" - "fmt" - "github.com/1f349/lavender/mail/templates" - "github.com/emersion/go-message/mail" -) - -func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error { - var bufHtml, bufTxt bytes.Buffer - templates.RenderMailTemplate(&bufHtml, &bufTxt, templateName, map[string]any{ - "ServiceName": m.Name, - "Name": nameOfUser, - "Data": data, - }) - return m.SendMail(fmt.Sprintf("%s - %s", subject, m.Name), []*mail.Address{to}, &bufHtml, &bufTxt) -} diff --git a/mail/templates/templates.go b/mail/templates/templates.go deleted file mode 100644 index ed82df3..0000000 --- a/mail/templates/templates.go +++ /dev/null @@ -1,55 +0,0 @@ -package templates - -import ( - "embed" - "errors" - "github.com/1f349/overlapfs" - "github.com/1f349/tulip/logger" - htmlTemplate "html/template" - "io" - "io/fs" - "os" - "path/filepath" - "sync" - textTemplate "text/template" -) - -var ( - //go:embed *.go.html *.go.txt - embeddedTemplates embed.FS - mailHtmlTemplates *htmlTemplate.Template - mailTextTemplates *textTemplate.Template - loadOnce sync.Once -) - -func LoadMailTemplates(wd string) (err error) { - loadOnce.Do(func() { - var o fs.FS = embeddedTemplates - if wd != "" { - mailDir := filepath.Join(wd, "mail-templates") - err = os.Mkdir(mailDir, os.ModePerm) - if err != nil && !errors.Is(err, os.ErrExist) { - return - } - wdFs := os.DirFS(mailDir) - o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs} - } - mailHtmlTemplates, err = htmlTemplate.New("mail").ParseFS(o, "*.go.html") - if err != nil { - return - } - mailTextTemplates, err = textTemplate.New("mail").ParseFS(o, "*.go.txt") - }) - return -} - -func RenderMailTemplate(wrHtml, wrTxt io.Writer, name string, data any) { - err := mailHtmlTemplates.ExecuteTemplate(wrHtml, name+".go.html", data) - if err != nil { - logger.Logger.Warn("Failed to render mail html", "name", name, "err", err) - } - err = mailTextTemplates.ExecuteTemplate(wrTxt, name+".go.txt", data) - if err != nil { - logger.Logger.Warn("Failed to render mail text", "name", name, "err", err) - } -} diff --git a/server/auth.go b/server/auth.go index 60e222f..79ebad8 100644 --- a/server/auth.go +++ b/server/auth.go @@ -59,7 +59,7 @@ func (h *httpServer) RequireAdminAuthentication(next UserHandler) httprouter.Han } func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle { - return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + return h.OptionalAuthentication(false, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { if auth.IsGuest() { redirectUrl := PrepareRedirectUrl("/login", req.URL) http.Redirect(rw, req, redirectUrl.String(), http.StatusFound) @@ -69,16 +69,20 @@ func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle { }) } -func (h *httpServer) OptionalAuthentication(next UserHandler) httprouter.Handle { +func (h *httpServer) OptionalAuthentication(flowPart bool, next UserHandler) httprouter.Handle { return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - authUser, err := h.internalAuthenticationHandler(rw, req) + authData, err := h.internalAuthenticationHandler(rw, req) if err != nil { if !errors.Is(err, ErrAuthHttpError) { http.Error(rw, err.Error(), http.StatusInternalServerError) } return } - next(rw, req, params, authUser) + if n := authData.NextFlowUrl(req.URL); n != nil && !flowPart { + http.Redirect(rw, req, n.String(), http.StatusFound) + return + } + next(rw, req, params, authData) } } diff --git a/server/auth_test.go b/server/auth_test.go new file mode 100644 index 0000000..68b6603 --- /dev/null +++ b/server/auth_test.go @@ -0,0 +1,73 @@ +package server + +import ( + "context" + "github.com/1f349/mjwt" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestUserAuth_NextFlowUrl(t *testing.T) { + u := UserAuth{NeedOtp: true} + assert.Equal(t, url.URL{Path: "/login/otp"}, *u.NextFlowUrl(&url.URL{})) + assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello"})) + assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()})) + u.NeedOtp = false + assert.Nil(t, u.NextFlowUrl(&url.URL{})) +} + +func TestUserAuth_IsGuest(t *testing.T) { + var u UserAuth + assert.True(t, u.IsGuest()) + u.Subject = uuid.NewString() + assert.False(t, u.IsGuest()) +} + +type fakeSessionStore struct { + m map[string]any + saveFunc func(map[string]any) error +} + +func (f *fakeSessionStore) Context() context.Context { return context.Background() } +func (f *fakeSessionStore) SessionID() string { return "fakeSessionStore" } +func (f *fakeSessionStore) Set(key string, value interface{}) { f.m[key] = value } + +func (f *fakeSessionStore) Get(key string) (a interface{}, ok bool) { + if a, ok = f.m[key]; false { + } + return +} + +func TestRequireAuthentication(t *testing.T) { +} + +func TestOptionalAuthentication(t *testing.T) { + jwtIssuer, err := mjwt.NewIssuer("TestIssuer", uuid.NewString(), jwt.SigningMethodRS512) + h := &httpServer{signingKey: jwtIssuer} + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "https://example.com/hello", nil) + assert.NoError(t, err) + auth, err := h.internalAuthenticationHandler(rec, req) + assert.NoError(t, err) + assert.True(t, auth.IsGuest()) + auth.Subject = "567" +} + +func TestPrepareRedirectUrl(t *testing.T) { + assert.Equal(t, url.URL{Path: "/hello"}, *PrepareRedirectUrl("/hello", &url.URL{})) + assert.Equal(t, url.URL{Path: "/world"}, *PrepareRedirectUrl("/world", &url.URL{})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello"})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()})) + + assert.Equal(t, url.URL{Path: "/hello", RawQuery: "z=y"}, *PrepareRedirectUrl("/hello?z=y", &url.URL{})) + assert.Equal(t, url.URL{Path: "/world", RawQuery: "z=y"}, *PrepareRedirectUrl("/world?z=y", &url.URL{})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello"})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()})) +} diff --git a/server/edit.go b/server/edit.go new file mode 100644 index 0000000..981cc0d --- /dev/null +++ b/server/edit.go @@ -0,0 +1,87 @@ +package server + +import ( + "fmt" + "github.com/1f349/lavender/database" + "github.com/1f349/lavender/lists" + "github.com/1f349/lavender/pages" + "github.com/google/uuid" + "github.com/julienschmidt/httprouter" + "net/http" + "time" +) + +func (h *httpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + var user database.User + + if h.DbTx(rw, func(tx *database.Queries) error { + var err error + user, err = tx.GetUser(req.Context(), auth.Subject) + if err != nil { + return fmt.Errorf("failed to read user data: %w", err) + } + return nil + }) { + return + } + + lNonce := uuid.NewString() + http.SetCookie(rw, &http.Cookie{ + Name: "tulip-nonce", + Value: lNonce, + Path: "/", + Expires: time.Now().Add(10 * time.Minute), + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + pages.RenderPageTemplate(rw, "edit", map[string]any{ + "ServiceName": h.conf.ServiceName, + "User": user, + "Nonce": lNonce, + "FieldPronoun": user.Pronouns.String(), + "ListZoneInfo": lists.ListZoneInfo(), + "ListLocale": lists.ListLocale(), + }) +} +func (h *httpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + if req.ParseForm() != nil { + rw.WriteHeader(http.StatusBadRequest) + _, _ = rw.Write([]byte("400 Bad Request\n")) + return + } + + var patch database.ProfilePatch + errs := patch.ParseFromForm(req.Form) + if len(errs) > 0 { + rw.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintln(rw, "\n\n
") + _, _ = fmt.Fprintln(rw, "400 Bad Request: Failed to parse form data, press the back button in your browser, check your inputs and try again.
") + _, _ = fmt.Fprintln(rw, "