From d25f9ae2cace21219fc11b666808645ec99c8433 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Sat, 5 Oct 2024 21:08:02 +0100 Subject: [PATCH] Fix a bunch more compile breaking issues --- auth/userinfofields.go | 19 ++ cmd/lavender/serve.go | 49 +++++ conf/conf.go | 17 +- database/manage-oauth.sql.go | 8 +- database/manage-users.sql.go | 174 ++++++++++++++-- .../migrations/20240820202502_init.up.sql | 25 ++- database/models.go | 56 ++--- database/otp.sql.go | 44 ++-- database/password-wrapper.go | 66 ++++-- database/profiles.sql.go | 50 +++-- database/queries/manage-oauth.sql | 2 +- database/queries/manage-users.sql | 54 ++++- database/queries/otp.sql | 24 ++- database/queries/profiles.sql | 13 +- database/queries/roles.sql | 8 + database/queries/users.sql | 16 +- database/roles.sql.go | 34 +++ database/tx.go | 26 +++ database/types/authtype.go | 2 +- database/users.sql.go | 66 +++++- go.mod | 10 +- go.sum | 8 +- issuer/manager.go | 1 + mail/from-address.go | 26 --- mail/mail.go | 116 +++-------- mail/send-template.go | 18 -- mail/templates/templates.go | 55 ----- server/auth.go | 12 +- server/auth_test.go | 73 +++++++ server/edit.go | 87 ++++++++ server/home.go | 47 +++++ server/login.go | 119 ++++++++--- server/logout.go | 25 +++ server/mail.go | 123 +++++++++++ server/manage-apps.go | 60 +++--- server/manage-users.go | 40 +++- server/oauth.go | 130 +++++++++++- server/otp.go | 196 ++++++++++++++++++ server/roles_test.go | 1 - server/server.go | 60 ++---- sqlc.yaml | 8 + utils/age.go | 28 +++ utils/age_test.go | 30 +++ 43 files changed, 1574 insertions(+), 452 deletions(-) create mode 100644 database/queries/roles.sql create mode 100644 database/roles.sql.go create mode 100644 database/tx.go delete mode 100644 mail/from-address.go delete mode 100644 mail/send-template.go delete mode 100644 mail/templates/templates.go create mode 100644 server/auth_test.go create mode 100644 server/edit.go create mode 100644 server/logout.go create mode 100644 server/mail.go create mode 100644 server/otp.go create mode 100644 utils/age.go create mode 100644 utils/age_test.go diff --git a/auth/userinfofields.go b/auth/userinfofields.go index 7f2093c..3411eef 100644 --- a/auth/userinfofields.go +++ b/auth/userinfofields.go @@ -1,5 +1,7 @@ package auth +import "github.com/hardfinhq/go-date" + type UserInfoFields map[string]any func (u UserInfoFields) GetString(key string) (string, bool) { @@ -20,7 +22,24 @@ func (u UserInfoFields) GetStringOrEmpty(key string) string { return s } +func (u UserInfoFields) GetStringFromKeysOrEmpty(keys ...string) string { + for _, key := range keys { + s, _ := u[key].(string) + if s == "" { + continue + } + return s + } + return "" +} + func (u UserInfoFields) GetBoolean(key string) (bool, bool) { b, ok := u[key].(bool) return b, ok } + +func (u UserInfoFields) GetNullDate(key string) date.NullDate { + s, _ := u[key].(string) + fromStr, err := date.FromString(s) + return date.NullDate{Date: fromStr, Valid: err == nil} +} diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index ebdb20a..f32367f 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -3,10 +3,13 @@ package main import ( "context" "flag" + "fmt" "github.com/1f349/lavender" "github.com/1f349/lavender/conf" + "github.com/1f349/lavender/database" "github.com/1f349/lavender/logger" "github.com/1f349/lavender/pages" + "github.com/1f349/lavender/role" "github.com/1f349/lavender/server" "github.com/1f349/mjwt" "github.com/charmbracelet/log" @@ -114,6 +117,10 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) logger.Logger.Fatal("Failed to open database", "err", err) } + if err := checkDbHasUser(db); err != nil { + logger.Logger.Fatal("Failed to add initial user", "err", err) + } + if err := pages.LoadPages(wd); err != nil { logger.Logger.Fatal("Failed to load page templates:", err) } @@ -168,3 +175,45 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) return subcommands.ExitSuccess } + +func checkDbHasUser(db *database.Queries) error { + value, err := db.HasUser(context.Background()) + if err != nil { + return err + } + + if !value { + logger.Logger.Warn("No users are available, setting up initial admin user") + + ctx := context.Background() + err = db.UseTx(ctx, func(tx *database.Queries) error { + adminUuid, err := db.AddLocalUser(context.Background(), database.AddLocalUserParams{ + Password: "admin", + Email: "admin@localhost", + EmailVerified: false, + Name: "Admin", + Username: "admin", + ChangePassword: true, + }) + if err != nil { + return fmt.Errorf("failed to add user: %w", err) + } + roleId, err := db.AddRole(context.Background(), role.LavenderAdmin) + if err != nil { + return fmt.Errorf("failed to add role: %w", err) + } + err = db.AddUserRole(context.Background(), database.AddUserRoleParams{ + RoleID: roleId, + Subject: adminUuid, + }) + if err != nil { + return fmt.Errorf("failed to add user role: %w", err) + } + return nil + }) + if err != nil { + return err + } + } + return nil +} diff --git a/conf/conf.go b/conf/conf.go index fd8f314..49c09eb 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -6,12 +6,13 @@ import ( ) type Conf struct { - Listen string `yaml:"listen"` - BaseUrl string `yaml:"baseUrl"` - ServiceName string `yaml:"serviceName"` - Issuer string `yaml:"issuer"` - Kid string `yaml:"kid"` - Namespace string `yaml:"namespace"` - Mail mail.Mail `yaml:"mail"` - SsoServices []issuer.SsoConfig `yaml:"ssoServices"` + Listen string `yaml:"listen"` + BaseUrl string `yaml:"baseUrl"` + ServiceName string `yaml:"serviceName"` + Issuer string `yaml:"issuer"` + Kid string `yaml:"kid"` + Namespace string `yaml:"namespace"` + OtpIssuer string `yaml:"otpIssuer"` + Mail mail.Mail `yaml:"mail"` + SsoServices map[string]issuer.SsoConfig `yaml:"ssoServices"` } diff --git a/database/manage-oauth.sql.go b/database/manage-oauth.sql.go index 02d4a56..7a7e4dc 100644 --- a/database/manage-oauth.sql.go +++ b/database/manage-oauth.sql.go @@ -20,14 +20,14 @@ SELECT subject, active FROM client_store WHERE owner_subject = ? - OR ? = 1 + OR CAST(? AS BOOLEAN) = 1 LIMIT 25 OFFSET ? ` type GetAppListParams struct { - OwnerSubject string `json:"owner_subject"` - Column2 interface{} `json:"column_2"` - Offset int64 `json:"offset"` + OwnerSubject string `json:"owner_subject"` + Column2 bool `json:"column_2"` + Offset int64 `json:"offset"` } type GetAppListRow struct { diff --git a/database/manage-users.sql.go b/database/manage-users.sql.go index 8b0a99a..aa05314 100644 --- a/database/manage-users.sql.go +++ b/database/manage-users.sql.go @@ -7,10 +7,30 @@ package database import ( "context" + "database/sql" "strings" "time" + + "github.com/1f349/lavender/database/types" ) +const addUserRole = `-- name: AddUserRole :exec +INSERT INTO users_roles(role_id, user_id) +SELECT ?, users.id +FROM users +WHERE subject = ? +` + +type AddUserRoleParams struct { + RoleID int64 `json:"role_id"` + Subject string `json:"subject"` +} + +func (q *Queries) AddUserRole(ctx context.Context, arg AddUserRoleParams) error { + _, err := q.db.ExecContext(ctx, addUserRole, arg.RoleID, arg.Subject) + return err +} + const changeUserActive = `-- name: ChangeUserActive :exec UPDATE users SET active = cast(? as boolean) @@ -34,26 +54,24 @@ SELECT users.subject, website, email, email_verified, - users.updated_at as user_updated_at, - p.updated_at as profile_updated_at, + updated_at, active FROM users - INNER JOIN main.profiles p on users.subject = p.subject LIMIT 50 OFFSET ? ` type GetUserListRow struct { - Subject string `json:"subject"` - Name string `json:"name"` - Picture string `json:"picture"` - Website string `json:"website"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UserUpdatedAt time.Time `json:"user_updated_at"` - ProfileUpdatedAt time.Time `json:"profile_updated_at"` - Active bool `json:"active"` + Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` } +// INNER JOIN main.profiles p on users.subject = p.subject func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) { rows, err := q.db.QueryContext(ctx, getUserList, offset) if err != nil { @@ -70,8 +88,7 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR &i.Website, &i.Email, &i.EmailVerified, - &i.UserUpdatedAt, - &i.ProfileUpdatedAt, + &i.UpdatedAt, &i.Active, ); err != nil { return nil, err @@ -87,6 +104,25 @@ func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListR return items, nil } +const getUserToken = `-- name: GetUserToken :one +SELECT access_token, refresh_token, token_expiry +FROM users +WHERE subject = ? +` + +type GetUserTokenRow struct { + AccessToken sql.NullString `json:"access_token"` + RefreshToken sql.NullString `json:"refresh_token"` + TokenExpiry sql.NullTime `json:"token_expiry"` +} + +func (q *Queries) GetUserToken(ctx context.Context, subject string) (GetUserTokenRow, error) { + row := q.db.QueryRowContext(ctx, getUserToken, subject) + var i GetUserTokenRow + err := row.Scan(&i.AccessToken, &i.RefreshToken, &i.TokenExpiry) + return i, err +} + const getUsersRoles = `-- name: GetUsersRoles :many SELECT r.role, u.id FROM users_roles @@ -133,6 +169,105 @@ func (q *Queries) GetUsersRoles(ctx context.Context, userIds []int64) ([]GetUser return items, nil } +const modifyUserAuth = `-- name: ModifyUserAuth :exec +UPDATE users +SET auth_type = ?, + auth_namespace=?, + auth_user = ? +WHERE subject = ? +` + +type ModifyUserAuthParams struct { + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` + Subject string `json:"subject"` +} + +func (q *Queries) ModifyUserAuth(ctx context.Context, arg ModifyUserAuthParams) error { + _, err := q.db.ExecContext(ctx, modifyUserAuth, + arg.AuthType, + arg.AuthNamespace, + arg.AuthUser, + arg.Subject, + ) + return err +} + +const modifyUserEmail = `-- name: ModifyUserEmail :exec +UPDATE users +SET email = ?, + email_verified=? +WHERE subject = ? +` + +type ModifyUserEmailParams struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Subject string `json:"subject"` +} + +func (q *Queries) ModifyUserEmail(ctx context.Context, arg ModifyUserEmailParams) error { + _, err := q.db.ExecContext(ctx, modifyUserEmail, arg.Email, arg.EmailVerified, arg.Subject) + return err +} + +const modifyUserRemoteLogin = `-- name: ModifyUserRemoteLogin :exec +UPDATE users +SET login = ?, + profile_url = ? +WHERE subject = ? +` + +type ModifyUserRemoteLoginParams struct { + Login string `json:"login"` + ProfileUrl string `json:"profile_url"` + Subject string `json:"subject"` +} + +func (q *Queries) ModifyUserRemoteLogin(ctx context.Context, arg ModifyUserRemoteLoginParams) error { + _, err := q.db.ExecContext(ctx, modifyUserRemoteLogin, arg.Login, arg.ProfileUrl, arg.Subject) + return err +} + +const removeUserRoles = `-- name: RemoveUserRoles :exec +DELETE +FROM users_roles +WHERE user_id IN (SELECT id + FROM users + WHERE subject = ?) +` + +func (q *Queries) RemoveUserRoles(ctx context.Context, subject string) error { + _, err := q.db.ExecContext(ctx, removeUserRoles, subject) + return err +} + +const updateUserToken = `-- name: UpdateUserToken :exec +UPDATE users +SET access_token = ?, + refresh_token=?, + token_expiry = ? +WHERE subject = ? +` + +type UpdateUserTokenParams struct { + AccessToken sql.NullString `json:"access_token"` + RefreshToken sql.NullString `json:"refresh_token"` + TokenExpiry sql.NullTime `json:"token_expiry"` + Subject string `json:"subject"` +} + +func (q *Queries) UpdateUserToken(ctx context.Context, arg UpdateUserTokenParams) error { + _, err := q.db.ExecContext(ctx, updateUserToken, + arg.AccessToken, + arg.RefreshToken, + arg.TokenExpiry, + arg.Subject, + ) + return err +} + const userEmailExists = `-- name: UserEmailExists :one SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists ` @@ -143,3 +278,14 @@ func (q *Queries) UserEmailExists(ctx context.Context, email string) (bool, erro err := row.Scan(&email_exists) return email_exists, err } + +const verifyUserEmail = `-- name: VerifyUserEmail :exec +UPDATE users +SET email_verified=1 +WHERE subject = ? +` + +func (q *Queries) VerifyUserEmail(ctx context.Context, subject string) error { + _, err := q.db.ExecContext(ctx, verifyUserEmail, subject) + return err +} diff --git a/database/migrations/20240820202502_init.up.sql b/database/migrations/20240820202502_init.up.sql index d78164e..3ba0bf3 100644 --- a/database/migrations/20240820202502_init.up.sql +++ b/database/migrations/20240820202502_init.up.sql @@ -21,9 +21,21 @@ CREATE TABLE users zone TEXT NOT NULL DEFAULT 'UTC', locale TEXT NOT NULL DEFAULT 'en-US', + login TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + auth_type INTEGER NOT NULL, auth_namespace TEXT NOT NULL, - auth_user TEXT NOT NULL + auth_user TEXT NOT NULL, + + access_token TEXT NULL DEFAULT NULL, + refresh_token TEXT NULL DEFAULT NULL, + token_expiry DATETIME NULL DEFAULT NULL, + + otp_secret TEXT NOT NULL DEFAULT '', + otp_digits INTEGER NOT NULL DEFAULT 0, + + to_delete BOOLEAN NOT NULL DEFAULT 0 ); CREATE INDEX users_subject ON users (subject); @@ -39,21 +51,12 @@ CREATE TABLE users_roles role_id INTEGER NOT NULL, user_id INTEGER NOT NULL, - FOREIGN KEY (role_id) REFERENCES roles (id), + FOREIGN KEY (role_id) REFERENCES roles (id) ON DELETE RESTRICT, FOREIGN KEY (user_id) REFERENCES users (id), CONSTRAINT user_role UNIQUE (role_id, user_id) ); -CREATE TABLE otp -( - subject INTEGER NOT NULL UNIQUE PRIMARY KEY, - secret TEXT NOT NULL, - digits INTEGER NOT NULL, - - FOREIGN KEY (subject) REFERENCES users (subject) -); - CREATE TABLE client_store ( subject TEXT NOT NULL UNIQUE PRIMARY KEY, diff --git a/database/models.go b/database/models.go index fc0b48f..ce08f56 100644 --- a/database/models.go +++ b/database/models.go @@ -5,9 +5,12 @@ package database import ( + "database/sql" "time" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" + "github.com/hardfinhq/go-date" ) type ClientStore struct { @@ -22,38 +25,39 @@ type ClientStore struct { Active bool `json:"active"` } -type Otp struct { - Subject int64 `json:"subject"` - Secret string `json:"secret"` - Digits int64 `json:"digits"` -} - -type Profile struct { - Subject string `json:"subject"` - Name string `json:"name"` - Picture string `json:"picture"` - Website string `json:"website"` - Pronouns string `json:"pronouns"` - Birthdate interface{} `json:"birthdate"` - Zone string `json:"zone"` - Locale string `json:"locale"` - UpdatedAt time.Time `json:"updated_at"` -} - type Role struct { ID int64 `json:"id"` Role string `json:"role"` } type User struct { - ID int64 `json:"id"` - Subject string `json:"subject"` - Password password.HashString `json:"password"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UpdatedAt time.Time `json:"updated_at"` - Registered time.Time `json:"registered"` - Active bool `json:"active"` + ID int64 `json:"id"` + Subject string `json:"subject"` + Password password.HashString `json:"password"` + ChangePassword bool `json:"change_password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Registered time.Time `json:"registered"` + Active bool `json:"active"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate date.NullDate `json:"birthdate"` + Zone string `json:"zone"` + Locale types.UserLocale `json:"locale"` + Login string `json:"login"` + ProfileUrl string `json:"profile_url"` + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` + AccessToken sql.NullString `json:"access_token"` + RefreshToken sql.NullString `json:"refresh_token"` + TokenExpiry sql.NullTime `json:"token_expiry"` + OtpSecret string `json:"otp_secret"` + OtpDigits int64 `json:"otp_digits"` + ToDelete bool `json:"to_delete"` } type UsersRole struct { diff --git a/database/otp.sql.go b/database/otp.sql.go index fc30726..04197cb 100644 --- a/database/otp.sql.go +++ b/database/otp.sql.go @@ -10,31 +10,32 @@ import ( ) const deleteOtp = `-- name: DeleteOtp :exec -DELETE -FROM otp -WHERE otp.subject = ? +UPDATE users +SET otp_secret='', + otp_digits=0 +WHERE subject = ? ` -func (q *Queries) DeleteOtp(ctx context.Context, subject int64) error { +func (q *Queries) DeleteOtp(ctx context.Context, subject string) error { _, err := q.db.ExecContext(ctx, deleteOtp, subject) return err } const getOtp = `-- name: GetOtp :one -SELECT secret, digits -FROM otp +SELECT otp_secret, otp_digits +FROM users WHERE subject = ? ` type GetOtpRow struct { - Secret string `json:"secret"` - Digits int64 `json:"digits"` + OtpSecret string `json:"otp_secret"` + OtpDigits int64 `json:"otp_digits"` } -func (q *Queries) GetOtp(ctx context.Context, subject int64) (GetOtpRow, error) { +func (q *Queries) GetOtp(ctx context.Context, subject string) (GetOtpRow, error) { row := q.db.QueryRowContext(ctx, getOtp, subject) var i GetOtpRow - err := row.Scan(&i.Secret, &i.Digits) + err := row.Scan(&i.OtpSecret, &i.OtpDigits) return i, err } @@ -52,10 +53,13 @@ func (q *Queries) GetUserEmail(ctx context.Context, subject string) (string, err } const hasOtp = `-- name: HasOtp :one -SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp +SELECT CAST(1 AS BOOLEAN) AS hasOtp +FROM users +WHERE subject = ? + AND otp_secret != '' ` -func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) { +func (q *Queries) HasOtp(ctx context.Context, subject string) (bool, error) { row := q.db.QueryRowContext(ctx, hasOtp, subject) var hasotp bool err := row.Scan(&hasotp) @@ -63,19 +67,19 @@ func (q *Queries) HasOtp(ctx context.Context, subject int64) (bool, error) { } const setOtp = `-- name: SetOtp :exec -INSERT OR -REPLACE -INTO otp (subject, secret, digits) -VALUES (?, ?, ?) +UPDATE users +SET otp_secret = ?, + otp_digits=? +WHERE subject = ? ` type SetOtpParams struct { - Subject int64 `json:"subject"` - Secret string `json:"secret"` - Digits int64 `json:"digits"` + OtpSecret string `json:"otp_secret"` + OtpDigits int64 `json:"otp_digits"` + Subject string `json:"subject"` } func (q *Queries) SetOtp(ctx context.Context, arg SetOtpParams) error { - _, err := q.db.ExecContext(ctx, setOtp, arg.Subject, arg.Secret, arg.Digits) + _, err := q.db.ExecContext(ctx, setOtp, arg.OtpSecret, arg.OtpDigits, arg.Subject) return err } diff --git a/database/password-wrapper.go b/database/password-wrapper.go index 07f94ee..f30b251 100644 --- a/database/password-wrapper.go +++ b/database/password-wrapper.go @@ -2,35 +2,69 @@ package database import ( "context" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" "github.com/google/uuid" "time" ) -type AddUserParams struct { - Name string `json:"name"` - Subject string `json:"subject"` - Password string `json:"password"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UpdatedAt time.Time `json:"updated_at"` - Active bool `json:"active"` +type AddLocalUserParams struct { + Password string `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Username string `json:"username"` + ChangePassword bool `json:"change_password"` } -func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) { +func (q *Queries) AddLocalUser(ctx context.Context, arg AddLocalUserParams) (string, error) { pwHash, err := password.HashPassword(arg.Password) if err != nil { return "", err } n := time.Now() a := addUserParams{ - Subject: uuid.NewString(), - Password: pwHash, - Email: arg.Email, - EmailVerified: arg.EmailVerified, - UpdatedAt: n, - Registered: n, - Active: true, + Subject: uuid.NewString(), + Password: pwHash, + Email: arg.Email, + EmailVerified: arg.EmailVerified, + UpdatedAt: n, + Registered: n, + Active: true, + Name: arg.Name, + Login: arg.Username, + ChangePassword: arg.ChangePassword, + AuthType: types.AuthTypeLocal, + AuthNamespace: "", + AuthUser: arg.Username, + } + return a.Subject, q.addUser(ctx, a) +} + +type AddOAuthUserParams struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Username string `json:"username"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` +} + +func (q *Queries) AddOAuthUser(ctx context.Context, arg AddOAuthUserParams) (string, error) { + n := time.Now() + a := addUserParams{ + Subject: uuid.NewString(), + Email: arg.Email, + EmailVerified: arg.EmailVerified, + UpdatedAt: n, + Registered: n, + Active: true, + Name: arg.Name, + Login: arg.Username, + ChangePassword: false, + AuthType: types.AuthTypeOauth2, + AuthNamespace: arg.AuthNamespace, + AuthUser: arg.AuthUser, } return a.Subject, q.addUser(ctx, a) } diff --git a/database/profiles.sql.go b/database/profiles.sql.go index fd54f5c..d92f220 100644 --- a/database/profiles.sql.go +++ b/database/profiles.sql.go @@ -8,17 +8,38 @@ package database import ( "context" "time" + + "github.com/1f349/lavender/database/types" + "github.com/hardfinhq/go-date" ) const getProfile = `-- name: GetProfile :one -SELECT profiles.subject, profiles.name, profiles.picture, profiles.website, profiles.pronouns, profiles.birthdate, profiles.zone, profiles.locale, profiles.updated_at -FROM profiles +SELECT subject, + name, + picture, + website, + pronouns, + birthdate, + zone, + locale +FROM users WHERE subject = ? ` -func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, error) { +type GetProfileRow struct { + Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate date.NullDate `json:"birthdate"` + Zone string `json:"zone"` + Locale types.UserLocale `json:"locale"` +} + +func (q *Queries) GetProfile(ctx context.Context, subject string) (GetProfileRow, error) { row := q.db.QueryRowContext(ctx, getProfile, subject) - var i Profile + var i GetProfileRow err := row.Scan( &i.Subject, &i.Name, @@ -28,13 +49,12 @@ func (q *Queries) GetProfile(ctx context.Context, subject string) (Profile, erro &i.Birthdate, &i.Zone, &i.Locale, - &i.UpdatedAt, ) return i, err } const modifyProfile = `-- name: ModifyProfile :exec -UPDATE profiles +UPDATE users SET name = ?, picture = ?, website = ?, @@ -47,15 +67,15 @@ WHERE subject = ? ` type ModifyProfileParams struct { - Name string `json:"name"` - Picture string `json:"picture"` - Website string `json:"website"` - Pronouns string `json:"pronouns"` - Birthdate interface{} `json:"birthdate"` - Zone string `json:"zone"` - Locale string `json:"locale"` - UpdatedAt time.Time `json:"updated_at"` - Subject string `json:"subject"` + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns types.UserPronoun `json:"pronouns"` + Birthdate date.NullDate `json:"birthdate"` + Zone string `json:"zone"` + Locale types.UserLocale `json:"locale"` + UpdatedAt time.Time `json:"updated_at"` + Subject string `json:"subject"` } func (q *Queries) ModifyProfile(ctx context.Context, arg ModifyProfileParams) error { diff --git a/database/queries/manage-oauth.sql b/database/queries/manage-oauth.sql index 7225f40..12cb631 100644 --- a/database/queries/manage-oauth.sql +++ b/database/queries/manage-oauth.sql @@ -15,7 +15,7 @@ SELECT subject, active FROM client_store WHERE owner_subject = ? - OR ? = 1 + OR CAST(? AS BOOLEAN) = 1 LIMIT 25 OFFSET ?; -- name: InsertClientApp :exec diff --git a/database/queries/manage-users.sql b/database/queries/manage-users.sql index 587b87e..b9814bb 100644 --- a/database/queries/manage-users.sql +++ b/database/queries/manage-users.sql @@ -5,11 +5,10 @@ SELECT users.subject, website, email, email_verified, - users.updated_at as user_updated_at, - p.updated_at as profile_updated_at, + updated_at, active FROM users - INNER JOIN main.profiles p on users.subject = p.subject +--INNER JOIN main.profiles p on users.subject = p.subject LIMIT 50 OFFSET ?; -- name: GetUsersRoles :many @@ -24,5 +23,54 @@ UPDATE users SET active = cast(? as boolean) WHERE subject = ?; +-- name: VerifyUserEmail :exec +UPDATE users +SET email_verified=1 +WHERE subject = ?; + -- name: UserEmailExists :one SELECT EXISTS(SELECT 1 FROM users WHERE email = ? AND email_verified = 1) == 1 AS email_exists; + +-- name: ModifyUserEmail :exec +UPDATE users +SET email = ?, + email_verified=? +WHERE subject = ?; + +-- name: ModifyUserAuth :exec +UPDATE users +SET auth_type = ?, + auth_namespace=?, + auth_user = ? +WHERE subject = ?; + +-- name: ModifyUserRemoteLogin :exec +UPDATE users +SET login = ?, + profile_url = ? +WHERE subject = ?; + +-- name: UpdateUserToken :exec +UPDATE users +SET access_token = ?, + refresh_token=?, + token_expiry = ? +WHERE subject = ?; + +-- name: GetUserToken :one +SELECT access_token, refresh_token, token_expiry +FROM users +WHERE subject = ?; + +-- name: RemoveUserRoles :exec +DELETE +FROM users_roles +WHERE user_id IN (SELECT id + FROM users + WHERE subject = ?); + +-- name: AddUserRole :exec +INSERT INTO users_roles(role_id, user_id) +SELECT ?, users.id +FROM users +WHERE subject = ?; diff --git a/database/queries/otp.sql b/database/queries/otp.sql index 175399a..94599c5 100644 --- a/database/queries/otp.sql +++ b/database/queries/otp.sql @@ -1,21 +1,25 @@ -- name: SetOtp :exec -INSERT OR -REPLACE -INTO otp (subject, secret, digits) -VALUES (?, ?, ?); +UPDATE users +SET otp_secret = ?, + otp_digits=? +WHERE subject = ?; -- name: DeleteOtp :exec -DELETE -FROM otp -WHERE otp.subject = ?; +UPDATE users +SET otp_secret='', + otp_digits=0 +WHERE subject = ?; -- name: GetOtp :one -SELECT secret, digits -FROM otp +SELECT otp_secret, otp_digits +FROM users WHERE subject = ?; -- name: HasOtp :one -SELECT EXISTS(SELECT 1 FROM otp WHERE subject = ?) == 1 as hasOtp; +SELECT CAST(1 AS BOOLEAN) AS hasOtp +FROM users +WHERE subject = ? + AND otp_secret != ''; -- name: GetUserEmail :one SELECT email diff --git a/database/queries/profiles.sql b/database/queries/profiles.sql index 134da89..74203c6 100644 --- a/database/queries/profiles.sql +++ b/database/queries/profiles.sql @@ -1,10 +1,17 @@ -- name: GetProfile :one -SELECT profiles.* -FROM profiles +SELECT subject, + name, + picture, + website, + pronouns, + birthdate, + zone, + locale +FROM users WHERE subject = ?; -- name: ModifyProfile :exec -UPDATE profiles +UPDATE users SET name = ?, picture = ?, website = ?, diff --git a/database/queries/roles.sql b/database/queries/roles.sql new file mode 100644 index 0000000..3c7d54e --- /dev/null +++ b/database/queries/roles.sql @@ -0,0 +1,8 @@ +-- name: AddRole :execlastid +INSERT OR IGNORE INTO roles(role) +VALUES (?); + +-- name: RemoveRole :exec +DELETE +FROM roles +WHERE role = ?; diff --git a/database/queries/users.sql b/database/queries/users.sql index 1a916aa..0d33bfc 100644 --- a/database/queries/users.sql +++ b/database/queries/users.sql @@ -3,15 +3,11 @@ SELECT count(subject) > 0 AS hasUser FROM users; -- name: addUser :exec -INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) -VALUES (?, ?, ?, ?, ?, ?, ?); - --- name: addOAuthUser :exec -INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) -VALUES (?, ?, ?, ?, ?, ?, ?); +INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); -- name: checkLogin :one -SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified +SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified FROM users WHERE users.subject = ? LIMIT 1; @@ -48,3 +44,9 @@ SET password = ?, updated_at=? WHERE subject = ? AND password = ?; + +-- name: FlagUserAsDeleted :exec +UPDATE users +SET active= false, + to_delete = true +WHERE subject = ?; diff --git a/database/roles.sql.go b/database/roles.sql.go new file mode 100644 index 0000000..fa557bd --- /dev/null +++ b/database/roles.sql.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: roles.sql + +package database + +import ( + "context" +) + +const addRole = `-- name: AddRole :execlastid +INSERT OR IGNORE INTO roles(role) +VALUES (?) +` + +func (q *Queries) AddRole(ctx context.Context, role string) (int64, error) { + result, err := q.db.ExecContext(ctx, addRole, role) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +const removeRole = `-- name: RemoveRole :exec +DELETE +FROM roles +WHERE role = ? +` + +func (q *Queries) RemoveRole(ctx context.Context, role string) error { + _, err := q.db.ExecContext(ctx, removeRole, role) + return err +} diff --git a/database/tx.go b/database/tx.go new file mode 100644 index 0000000..15f8522 --- /dev/null +++ b/database/tx.go @@ -0,0 +1,26 @@ +package database + +import ( + "context" + "database/sql" + "errors" +) + +var errCannotOpenTransactionWithoutSqlDB = errors.New("cannot open transaction without sql.DB") + +func (q *Queries) UseTx(ctx context.Context, cb func(tx *Queries) error) error { + sqlDB, ok := q.db.(*sql.DB) + if !ok { + panic(errCannotOpenTransactionWithoutSqlDB) + } + tx, err := sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + err = cb(q.WithTx(tx)) + if err != nil { + return err + } + return tx.Commit() +} diff --git a/database/types/authtype.go b/database/types/authtype.go index 903f717..09ee32b 100644 --- a/database/types/authtype.go +++ b/database/types/authtype.go @@ -3,7 +3,7 @@ package types type AuthType byte const ( - AuthTypeBase AuthType = iota + AuthTypeLocal AuthType = iota AuthTypeOauth2 ) diff --git a/database/users.sql.go b/database/users.sql.go index 04c1c02..7b7b6ef 100644 --- a/database/users.sql.go +++ b/database/users.sql.go @@ -9,11 +9,24 @@ import ( "context" "time" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" ) +const flagUserAsDeleted = `-- name: FlagUserAsDeleted :exec +UPDATE users +SET active= false, + to_delete = true +WHERE subject = ? +` + +func (q *Queries) FlagUserAsDeleted(ctx context.Context, subject string) error { + _, err := q.db.ExecContext(ctx, flagUserAsDeleted, subject) + return err +} + const getUser = `-- name: GetUser :one -SELECT id, subject, password, email, email_verified, updated_at, registered, active +SELECT id, subject, password, change_password, email, email_verified, updated_at, registered, active, name, picture, website, pronouns, birthdate, zone, locale, login, profile_url, auth_type, auth_namespace, auth_user, access_token, refresh_token, token_expiry, otp_secret, otp_digits, to_delete FROM users WHERE subject = ? LIMIT 1 @@ -26,11 +39,30 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) { &i.ID, &i.Subject, &i.Password, + &i.ChangePassword, &i.Email, &i.EmailVerified, &i.UpdatedAt, &i.Registered, &i.Active, + &i.Name, + &i.Picture, + &i.Website, + &i.Pronouns, + &i.Birthdate, + &i.Zone, + &i.Locale, + &i.Login, + &i.ProfileUrl, + &i.AuthType, + &i.AuthNamespace, + &i.AuthUser, + &i.AccessToken, + &i.RefreshToken, + &i.TokenExpiry, + &i.OtpSecret, + &i.OtpDigits, + &i.ToDelete, ) return i, err } @@ -98,18 +130,24 @@ func (q *Queries) UserHasRole(ctx context.Context, arg UserHasRoleParams) error } const addUser = `-- name: addUser :exec -INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) -VALUES (?, ?, ?, ?, ?, ?, ?) +INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active, name, login, change_password, auth_type, auth_namespace, auth_user) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` type addUserParams struct { - Subject string `json:"subject"` - Password password.HashString `json:"password"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - UpdatedAt time.Time `json:"updated_at"` - Registered time.Time `json:"registered"` - Active bool `json:"active"` + Subject string `json:"subject"` + Password password.HashString `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Registered time.Time `json:"registered"` + Active bool `json:"active"` + Name string `json:"name"` + Login string `json:"login"` + ChangePassword bool `json:"change_password"` + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` } func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { @@ -121,6 +159,12 @@ func (q *Queries) addUser(ctx context.Context, arg addUserParams) error { arg.UpdatedAt, arg.Registered, arg.Active, + arg.Name, + arg.Login, + arg.ChangePassword, + arg.AuthType, + arg.AuthNamespace, + arg.AuthUser, ) return err } @@ -151,7 +195,7 @@ func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPassword } const checkLogin = `-- name: checkLogin :one -SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified +SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified FROM users WHERE users.subject = ? LIMIT 1 diff --git a/go.mod b/go.mod index dc2b926..63fa064 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,10 @@ require ( github.com/1f349/cache v0.0.3 github.com/1f349/mjwt v0.4.1 github.com/1f349/overlapfs v0.0.1 - github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 + github.com/1f349/simplemail v0.0.5 github.com/charmbracelet/log v0.4.0 github.com/cloudflare/tableflip v1.2.3 github.com/emersion/go-message v0.18.1 - github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 - github.com/emersion/go-smtp v0.21.3 github.com/go-oauth2/oauth2/v4 v4.5.2 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/golang-migrate/migrate/v4 v4.17.1 @@ -21,10 +19,13 @@ require ( github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mrmelon54/pronouns v1.0.3 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/spf13/afero v1.11.0 github.com/stretchr/testify v1.9.0 + github.com/xlzd/gotp v0.1.0 golang.org/x/crypto v0.26.0 golang.org/x/oauth2 v0.22.0 + golang.org/x/sync v0.8.0 golang.org/x/text v0.17.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -36,6 +37,8 @@ require ( github.com/charmbracelet/lipgloss v0.12.1 // indirect github.com/charmbracelet/x/ansi v0.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 // indirect + github.com/emersion/go-smtp v0.21.3 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect @@ -63,6 +66,5 @@ require ( go.uber.org/atomic v1.11.0 // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect golang.org/x/net v0.28.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect ) diff --git a/go.sum b/go.sum index 98374f7..7717885 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/1f349/overlapfs v0.0.1 h1:LAxBolrXFAgU0yqZtXg/C/aaPq3eoQSPpBc49BHuTp0 github.com/1f349/overlapfs v0.0.1/go.mod h1:I6aItQycr7nrzplmfNXp/QF9tTmKRSgY3fXmu/7Ky2o= github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9U= github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg= -github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63 h1:jPg+0bgKD5kY7yQtRZqeba+BGKFE51evGvwewZwa7Xc= -github.com/1f349/tulip v0.0.0-20240725211619-6b19e2d4ca63/go.mod h1:1zFQhcbgiyPSWHVMp0cXJjmd6FhasP5bf5tWS4ZK61A= +github.com/1f349/simplemail v0.0.5 h1:cr+8pdWhFE/+XVSO7ZTjntySbmIbTqmDy2SR9cHAPLE= +github.com/1f349/simplemail v0.0.5/go.mod h1:ppAIqkvVkI6L99EefbR5NgOjpePNK/RKgeoehj5A+kU= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= @@ -146,6 +146,8 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99 github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= @@ -199,6 +201,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/xlzd/gotp v0.1.0 h1:37blvlKCh38s+fkem+fFh7sMnceltoIEBYTVXyoa5Po= +github.com/xlzd/gotp v0.1.0/go.mod h1:ndLJ3JKzi3xLmUProq4LLxCuECL93dG9WASNLpHz8qg= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= diff --git a/issuer/manager.go b/issuer/manager.go index 87b74dc..8520c15 100644 --- a/issuer/manager.go +++ b/issuer/manager.go @@ -25,6 +25,7 @@ func NewManager(services map[string]SsoConfig) (*Manager, error) { } // save by namespace + conf.Namespace = namespace l.m[namespace] = conf } return l, nil diff --git a/mail/from-address.go b/mail/from-address.go deleted file mode 100644 index e52f5f8..0000000 --- a/mail/from-address.go +++ /dev/null @@ -1,26 +0,0 @@ -package mail - -import ( - "encoding/json" - "github.com/emersion/go-message/mail" -) - -type FromAddress struct { - *mail.Address -} - -var _ json.Unmarshaler = &FromAddress{} - -func (f *FromAddress) UnmarshalJSON(b []byte) error { - var a string - err := json.Unmarshal(b, &a) - if err != nil { - return err - } - address, err := mail.ParseAddress(a) - if err != nil { - return err - } - f.Address = address - return nil -} diff --git a/mail/mail.go b/mail/mail.go index 8403664..dca2e29 100644 --- a/mail/mail.go +++ b/mail/mail.go @@ -1,96 +1,48 @@ package mail import ( - "bytes" + "embed" + "errors" + "fmt" + "github.com/1f349/overlapfs" + "github.com/1f349/simplemail" "github.com/emersion/go-message/mail" - "github.com/emersion/go-sasl" - "github.com/emersion/go-smtp" - "io" - "net" - "time" + "io/fs" + "os" + "path/filepath" ) +//go:embed templates/*.go.html templates/*.go.txt +var embeddedTemplates embed.FS + type Mail struct { - Name string `json:"name"` - Tls bool `json:"tls"` - Server string `json:"server"` - From FromAddress `json:"from"` - Username string `json:"username"` - Password string `json:"password"` + mail *simplemail.SimpleMail + name string } -func (m *Mail) loginInfo() sasl.Client { - return sasl.NewPlainClient("", m.Username, m.Password) -} - -func (m *Mail) mailCall(to []string, r io.Reader) error { - host, _, err := net.SplitHostPort(m.Server) - if err != nil { - return err - } - if m.Tls { - return smtp.SendMailTLS(m.Server, m.loginInfo(), m.From.String(), to, r) - } - if host == "localhost" || host == "127.0.0.1" { - // internals of smtp.SendMail without STARTTLS for localhost testing - dial, err := smtp.Dial(m.Server) - if err != nil { - return err +func New(sender *simplemail.Mail, wd, name string) (*Mail, error) { + var o fs.FS = embeddedTemplates + o, _ = fs.Sub(o, "templates") + if wd != "" { + mailDir := filepath.Join(wd, "mail-templates") + err := os.Mkdir(mailDir, os.ModePerm) + if err == nil || errors.Is(err, os.ErrExist) { + wdFs := os.DirFS(mailDir) + o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs} } - err = dial.Auth(m.loginInfo()) - if err != nil { - return err - } - return dial.SendMail(m.From.String(), to, r) } - return smtp.SendMail(m.Server, m.loginInfo(), m.From.String(), to, r) + + simpleMail, err := simplemail.New(sender, o) + return &Mail{ + mail: simpleMail, + name: name, + }, err } -func (m *Mail) SendMail(subject string, to []*mail.Address, htmlBody, textBody io.Reader) error { - // generate the email in this template - buf := new(bytes.Buffer) - - // setup mail headers - var h mail.Header - h.SetDate(time.Now()) - h.SetSubject(subject) - h.SetAddressList("From", []*mail.Address{m.From.Address}) - h.SetAddressList("To", to) - h.Set("Content-Type", "multipart/alternative") - - // setup html and text alternative headers - var hHtml, hTxt mail.InlineHeader - hHtml.Set("Content-Type", "text/html; charset=utf-8") - hTxt.Set("Content-Type", "text/plain; charset=utf-8") - - createWriter, err := mail.CreateWriter(buf, h) - if err != nil { - return err - } - inline, err := createWriter.CreateInline() - if err != nil { - return err - } - partHtml, err := inline.CreatePart(hHtml) - if err != nil { - return err - } - if _, err := io.Copy(partHtml, htmlBody); err != nil { - return err - } - partTxt, err := inline.CreatePart(hTxt) - if err != nil { - return err - } - if _, err := io.Copy(partTxt, textBody); err != nil { - return err - } - - // convert all to addresses to strings - toStr := make([]string, len(to)) - for i := range toStr { - toStr[i] = to[i].String() - } - - return m.mailCall(toStr, buf) +func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error { + return m.mail.Send(templateName, fmt.Sprintf("%s - %s", subject, m.name), to, map[string]any{ + "ServiceName": m.name, + "Name": nameOfUser, + "Data": data, + }) } diff --git a/mail/send-template.go b/mail/send-template.go deleted file mode 100644 index 5f2c22f..0000000 --- a/mail/send-template.go +++ /dev/null @@ -1,18 +0,0 @@ -package mail - -import ( - "bytes" - "fmt" - "github.com/1f349/lavender/mail/templates" - "github.com/emersion/go-message/mail" -) - -func (m *Mail) SendEmailTemplate(templateName, subject, nameOfUser string, to *mail.Address, data map[string]any) error { - var bufHtml, bufTxt bytes.Buffer - templates.RenderMailTemplate(&bufHtml, &bufTxt, templateName, map[string]any{ - "ServiceName": m.Name, - "Name": nameOfUser, - "Data": data, - }) - return m.SendMail(fmt.Sprintf("%s - %s", subject, m.Name), []*mail.Address{to}, &bufHtml, &bufTxt) -} diff --git a/mail/templates/templates.go b/mail/templates/templates.go deleted file mode 100644 index ed82df3..0000000 --- a/mail/templates/templates.go +++ /dev/null @@ -1,55 +0,0 @@ -package templates - -import ( - "embed" - "errors" - "github.com/1f349/overlapfs" - "github.com/1f349/tulip/logger" - htmlTemplate "html/template" - "io" - "io/fs" - "os" - "path/filepath" - "sync" - textTemplate "text/template" -) - -var ( - //go:embed *.go.html *.go.txt - embeddedTemplates embed.FS - mailHtmlTemplates *htmlTemplate.Template - mailTextTemplates *textTemplate.Template - loadOnce sync.Once -) - -func LoadMailTemplates(wd string) (err error) { - loadOnce.Do(func() { - var o fs.FS = embeddedTemplates - if wd != "" { - mailDir := filepath.Join(wd, "mail-templates") - err = os.Mkdir(mailDir, os.ModePerm) - if err != nil && !errors.Is(err, os.ErrExist) { - return - } - wdFs := os.DirFS(mailDir) - o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs} - } - mailHtmlTemplates, err = htmlTemplate.New("mail").ParseFS(o, "*.go.html") - if err != nil { - return - } - mailTextTemplates, err = textTemplate.New("mail").ParseFS(o, "*.go.txt") - }) - return -} - -func RenderMailTemplate(wrHtml, wrTxt io.Writer, name string, data any) { - err := mailHtmlTemplates.ExecuteTemplate(wrHtml, name+".go.html", data) - if err != nil { - logger.Logger.Warn("Failed to render mail html", "name", name, "err", err) - } - err = mailTextTemplates.ExecuteTemplate(wrTxt, name+".go.txt", data) - if err != nil { - logger.Logger.Warn("Failed to render mail text", "name", name, "err", err) - } -} diff --git a/server/auth.go b/server/auth.go index 60e222f..79ebad8 100644 --- a/server/auth.go +++ b/server/auth.go @@ -59,7 +59,7 @@ func (h *httpServer) RequireAdminAuthentication(next UserHandler) httprouter.Han } func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle { - return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + return h.OptionalAuthentication(false, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { if auth.IsGuest() { redirectUrl := PrepareRedirectUrl("/login", req.URL) http.Redirect(rw, req, redirectUrl.String(), http.StatusFound) @@ -69,16 +69,20 @@ func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle { }) } -func (h *httpServer) OptionalAuthentication(next UserHandler) httprouter.Handle { +func (h *httpServer) OptionalAuthentication(flowPart bool, next UserHandler) httprouter.Handle { return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - authUser, err := h.internalAuthenticationHandler(rw, req) + authData, err := h.internalAuthenticationHandler(rw, req) if err != nil { if !errors.Is(err, ErrAuthHttpError) { http.Error(rw, err.Error(), http.StatusInternalServerError) } return } - next(rw, req, params, authUser) + if n := authData.NextFlowUrl(req.URL); n != nil && !flowPart { + http.Redirect(rw, req, n.String(), http.StatusFound) + return + } + next(rw, req, params, authData) } } diff --git a/server/auth_test.go b/server/auth_test.go new file mode 100644 index 0000000..68b6603 --- /dev/null +++ b/server/auth_test.go @@ -0,0 +1,73 @@ +package server + +import ( + "context" + "github.com/1f349/mjwt" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestUserAuth_NextFlowUrl(t *testing.T) { + u := UserAuth{NeedOtp: true} + assert.Equal(t, url.URL{Path: "/login/otp"}, *u.NextFlowUrl(&url.URL{})) + assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello"})) + assert.Equal(t, url.URL{Path: "/login/otp", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *u.NextFlowUrl(&url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()})) + u.NeedOtp = false + assert.Nil(t, u.NextFlowUrl(&url.URL{})) +} + +func TestUserAuth_IsGuest(t *testing.T) { + var u UserAuth + assert.True(t, u.IsGuest()) + u.Subject = uuid.NewString() + assert.False(t, u.IsGuest()) +} + +type fakeSessionStore struct { + m map[string]any + saveFunc func(map[string]any) error +} + +func (f *fakeSessionStore) Context() context.Context { return context.Background() } +func (f *fakeSessionStore) SessionID() string { return "fakeSessionStore" } +func (f *fakeSessionStore) Set(key string, value interface{}) { f.m[key] = value } + +func (f *fakeSessionStore) Get(key string) (a interface{}, ok bool) { + if a, ok = f.m[key]; false { + } + return +} + +func TestRequireAuthentication(t *testing.T) { +} + +func TestOptionalAuthentication(t *testing.T) { + jwtIssuer, err := mjwt.NewIssuer("TestIssuer", uuid.NewString(), jwt.SigningMethodRS512) + h := &httpServer{signingKey: jwtIssuer} + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "https://example.com/hello", nil) + assert.NoError(t, err) + auth, err := h.internalAuthenticationHandler(rec, req) + assert.NoError(t, err) + assert.True(t, auth.IsGuest()) + auth.Subject = "567" +} + +func TestPrepareRedirectUrl(t *testing.T) { + assert.Equal(t, url.URL{Path: "/hello"}, *PrepareRedirectUrl("/hello", &url.URL{})) + assert.Equal(t, url.URL{Path: "/world"}, *PrepareRedirectUrl("/world", &url.URL{})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello"})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()})) + + assert.Equal(t, url.URL{Path: "/hello", RawQuery: "z=y"}, *PrepareRedirectUrl("/hello?z=y", &url.URL{})) + assert.Equal(t, url.URL{Path: "/world", RawQuery: "z=y"}, *PrepareRedirectUrl("/world?z=y", &url.URL{})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello"})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}}.Encode()})) + assert.Equal(t, url.URL{Path: "/a", RawQuery: url.Values{"z": {"y"}, "redirect": {"/hello?a=A&b=B"}}.Encode()}, *PrepareRedirectUrl("/a?z=y", &url.URL{Path: "/hello", RawQuery: url.Values{"a": {"A"}, "b": {"B"}}.Encode()})) +} diff --git a/server/edit.go b/server/edit.go new file mode 100644 index 0000000..981cc0d --- /dev/null +++ b/server/edit.go @@ -0,0 +1,87 @@ +package server + +import ( + "fmt" + "github.com/1f349/lavender/database" + "github.com/1f349/lavender/lists" + "github.com/1f349/lavender/pages" + "github.com/google/uuid" + "github.com/julienschmidt/httprouter" + "net/http" + "time" +) + +func (h *httpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + var user database.User + + if h.DbTx(rw, func(tx *database.Queries) error { + var err error + user, err = tx.GetUser(req.Context(), auth.Subject) + if err != nil { + return fmt.Errorf("failed to read user data: %w", err) + } + return nil + }) { + return + } + + lNonce := uuid.NewString() + http.SetCookie(rw, &http.Cookie{ + Name: "tulip-nonce", + Value: lNonce, + Path: "/", + Expires: time.Now().Add(10 * time.Minute), + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + pages.RenderPageTemplate(rw, "edit", map[string]any{ + "ServiceName": h.conf.ServiceName, + "User": user, + "Nonce": lNonce, + "FieldPronoun": user.Pronouns.String(), + "ListZoneInfo": lists.ListZoneInfo(), + "ListLocale": lists.ListLocale(), + }) +} +func (h *httpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + if req.ParseForm() != nil { + rw.WriteHeader(http.StatusBadRequest) + _, _ = rw.Write([]byte("400 Bad Request\n")) + return + } + + var patch database.ProfilePatch + errs := patch.ParseFromForm(req.Form) + if len(errs) > 0 { + rw.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintln(rw, "\n\n") + _, _ = fmt.Fprintln(rw, "

400 Bad Request: Failed to parse form data, press the back button in your browser, check your inputs and try again.

") + _, _ = fmt.Fprintln(rw, "") + _, _ = fmt.Fprintln(rw, "\n") + return + } + m := database.ModifyProfileParams{ + Name: patch.Name, + Picture: patch.Picture, + Website: patch.Website, + Pronouns: patch.Pronouns, + Birthdate: patch.Birthdate, + Zone: patch.Zone.String(), + Locale: patch.Locale, + UpdatedAt: time.Now(), + Subject: auth.Subject, + } + if h.DbTx(rw, func(tx *database.Queries) error { + if err := tx.ModifyProfile(req.Context(), m); err != nil { + return fmt.Errorf("failed to modify user info: %w", err) + } + return nil + }) { + return + } + http.Redirect(rw, req, "/edit", http.StatusFound) +} diff --git a/server/home.go b/server/home.go index cc93e98..2b67a64 100644 --- a/server/home.go +++ b/server/home.go @@ -42,4 +42,51 @@ func (h *httpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute "Nonce": lNonce, "IsAdmin": isAdmin, }) + + // rw.Header().Set("Content-Type", "text/html") + // lNonce := uuid.NewString() + // http.SetCookie(rw, &http.Cookie{ + // Name: "tulip-nonce", + // Value: lNonce, + // Path: "/", + // Expires: time.Now().Add(10 * time.Minute), + // Secure: true, + // SameSite: http.SameSiteLaxMode, + // }) + // + // if auth.IsGuest() { + // pages.RenderPageTemplate(rw, "index-guest", map[string]any{ + // "ServiceName": h.conf.ServiceName, + // }) + // return + // } + // + // var userWithName string + // var userRole types.UserRole + // var hasTwoFactor bool + // if h.DbTx(rw, func(tx *database.Queries) (err error) { + // userWithName, err = tx.GetUserDisplayName(req.Context(), auth.Subject) + // if err != nil { + // return fmt.Errorf("failed to get user display name: %w", err) + // } + // hasTwoFactor, err = tx.HasOtp(req.Context(), auth.Subject) + // if err != nil { + // return fmt.Errorf("failed to get user two factor state: %w", err) + // } + // userRole, err = tx.GetUserRole(req.Context(), auth.Subject) + // if err != nil { + // return fmt.Errorf("failed to get user role: %w", err) + // } + // return + // }) { + // return + // } + // pages.RenderPageTemplate(rw, "index", map[string]any{ + // "ServiceName": h.conf.ServiceName, + // "Auth": auth, + // "User": database.User{Subject: auth.Subject, Name: userWithName, Role: userRole}, + // "Nonce": lNonce, + // "OtpEnabled": hasTwoFactor, + // "IsAdmin": userRole == types.RoleAdmin, + // }) } diff --git a/server/login.go b/server/login.go index 323a03c..96c56b6 100644 --- a/server/login.go +++ b/server/login.go @@ -8,6 +8,7 @@ import ( "fmt" auth2 "github.com/1f349/lavender/auth" "github.com/1f349/lavender/database" + "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/pages" "github.com/1f349/mjwt" @@ -15,13 +16,31 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/julienschmidt/httprouter" + "github.com/mrmelon54/pronouns" "golang.org/x/oauth2" + "golang.org/x/text/language" "net/http" "net/url" "strings" "time" ) +// getUserLoginName finds the `login_name` query parameter within the `/authorize` redirect url +func getUserLoginName(req *http.Request) string { + q := req.URL.Query() + if !q.Has("redirect") { + return "" + } + originUrl, err := url.ParseRequestURI(q.Get("redirect")) + if err != nil { + return "" + } + if originUrl.Path != "/authorize" { + return "" + } + return originUrl.Query().Get("login_name") +} + func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { if !auth.IsGuest() { h.SafeRedirect(rw, req) @@ -131,41 +150,70 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK } err = h.DbTxError(func(tx *database.Queries) error { - jBytes, err := json.Marshal(sessionData.UserInfo) + name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User") + + _, err = tx.GetUser(req.Context(), sessionData.Subject) + uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") + uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") + if errors.Is(err, sql.ErrNoRows) { + _, err := tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{ + Email: uEmail, + EmailVerified: uEmailVerified, + Name: name, + Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"), + AuthNamespace: sso.Namespace, + AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"), + }) + return err + } + + err = tx.ModifyUserEmail(req.Context(), database.ModifyUserEmailParams{ + Email: uEmail, + EmailVerified: uEmailVerified, + Subject: sessionData.Subject, + }) if err != nil { return err } - _, err = tx.GetUser(req.Context(), sessionData.Subject) - if errors.Is(err, sql.ErrNoRows) { - uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") - uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") - id, err := tx.AddUser(req.Context(), database.AddUserParams{ - Name: "", - Subject: sessionData.Subject, - Password: "", - Email: uEmail, - EmailVerified: uEmailVerified, - UpdatedAt: time.Now(), - Active: true, - }) + + err = tx.ModifyUserAuth(req.Context(), database.ModifyUserAuthParams{ + AuthType: types.AuthTypeOauth2, + AuthNamespace: sso.Namespace, + AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"), + Subject: sessionData.Subject, + }) + if err != nil { return err - return tx.AddUser(req.Context(), database.AddUserParams{ - Subject: sessionData.Subject, - Email: uEmail, - EmailVerified: uEmailVerified, - Roles: "", - Userinfo: string(jBytes), - UpdatedAt: time.Now(), - Active: true, - }) } - uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") - uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") - return tx.UpdateUserInfo(req.Context(), database.UpdateUserInfoParams{ - Email: sessionData.Subject, - EmailVerified: uEmailVerified, - Userinfo: string(jBytes), - Subject: uEmail, + + err = tx.ModifyUserRemoteLogin(req.Context(), database.ModifyUserRemoteLoginParams{ + Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"), + ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"), + Subject: sessionData.Subject, + }) + if err != nil { + return err + } + + pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns")) + if err != nil { + pronoun = pronouns.TheyThem + } + locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale")) + if err != nil { + locale = language.AmericanEnglish + } + + return tx.ModifyProfile(req.Context(), database.ModifyProfileParams{ + Name: name, + Picture: sessionData.UserInfo.GetStringOrEmpty("profile"), + Website: sessionData.UserInfo.GetStringOrEmpty("website"), + Pronouns: types.UserPronoun{Pronoun: pronoun}, + Birthdate: sessionData.UserInfo.GetNullDate("birthdate"), + Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"), + Locale: types.UserLocale{Tag: locale}, + UpdatedAt: time.Now(), + Subject: sessionData.Subject, }) }) if err != nil { @@ -177,7 +225,7 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{ AccessToken: sql.NullString{String: token.AccessToken, Valid: true}, RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true}, - Expiry: sql.NullTime{Time: token.Expiry, Valid: true}, + TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true}, Subject: sessionData.Subject, }) }); err != nil { @@ -208,6 +256,11 @@ func (l lavenderLoginRefresh) Valid() error { return l.RefreshTokenClaims.Valid( func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" } +func (h *httpServer) setLoginDataCookie2(rw http.ResponseWriter, authData UserAuth) bool { + // TODO(melon): should probably merge there methods + return h.setLoginDataCookie(rw, authData, "") +} + func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool { ps := auth.NewPermStorage() accId := uuid.NewString() @@ -286,13 +339,13 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re if err != nil { return err } - if !token.AccessToken.Valid || !token.RefreshToken.Valid || !token.Expiry.Valid { + if !token.AccessToken.Valid || !token.RefreshToken.Valid || !token.TokenExpiry.Valid { return fmt.Errorf("invalid oauth token") } oauthToken = &oauth2.Token{ AccessToken: token.AccessToken.String, RefreshToken: token.RefreshToken.String, - Expiry: token.Expiry.Time, + Expiry: token.TokenExpiry.Time, } return nil }) diff --git a/server/logout.go b/server/logout.go new file mode 100644 index 0000000..1d721d2 --- /dev/null +++ b/server/logout.go @@ -0,0 +1,25 @@ +package server + +import ( + "github.com/julienschmidt/httprouter" + "net/http" +) + +func (h *httpServer) logoutPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ UserAuth) { + http.SetCookie(rw, &http.Cookie{ + Name: "lavender-login-access", + Path: "/", + MaxAge: -1, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + http.SetCookie(rw, &http.Cookie{ + Name: "lavender-login-refresh", + Path: "/", + MaxAge: -1, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + http.Redirect(rw, req, "/", http.StatusFound) +} diff --git a/server/mail.go b/server/mail.go new file mode 100644 index 0000000..55759f7 --- /dev/null +++ b/server/mail.go @@ -0,0 +1,123 @@ +package server + +import ( + "github.com/1f349/lavender/database" + "github.com/1f349/lavender/pages" + "github.com/emersion/go-message/mail" + "github.com/julienschmidt/httprouter" + "net/http" +) + +func (h *httpServer) MailVerify(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + code := params.ByName("code") + + k := mailLinkKey{mailLinkVerifyEmail, code} + + userSub, ok := h.mailLinkCache.Get(k) + if !ok { + http.Error(rw, "Invalid email verification code", http.StatusBadRequest) + return + } + if h.DbTx(rw, func(tx *database.Queries) error { + return tx.VerifyUserEmail(req.Context(), userSub) + }) { + return + } + + h.mailLinkCache.Delete(k) + + http.Error(rw, "Email address has been verified, you may close this tab and return to the login page.", http.StatusOK) +} + +func (h *httpServer) MailPassword(rw http.ResponseWriter, _ *http.Request, params httprouter.Params) { + code := params.ByName("code") + + k := mailLinkKey{mailLinkResetPassword, code} + _, ok := h.mailLinkCache.Get(k) + if !ok { + http.Error(rw, "Invalid password reset code", http.StatusBadRequest) + return + } + + pages.RenderPageTemplate(rw, "reset-password", map[string]any{ + "ServiceName": h.conf.ServiceName, + "Code": code, + }) +} + +func (h *httpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { + pw := req.PostFormValue("new_password") + rpw := req.PostFormValue("confirm_password") + code := req.PostFormValue("code") + + // reverse passwords are possible + if len(pw) == 0 { + http.Error(rw, "Cannot set an empty password", http.StatusBadRequest) + return + } + // bcrypt only allows up to 72 bytes anyway + if len(pw) > 64 { + http.Error(rw, "Security by extremely long password is a weird flex", http.StatusBadRequest) + return + } + if rpw != pw { + http.Error(rw, "Passwords do not match", http.StatusBadRequest) + return + } + + k := mailLinkKey{mailLinkResetPassword, code} + userSub, ok := h.mailLinkCache.Get(k) + if !ok { + http.Error(rw, "Invalid password reset code", http.StatusBadRequest) + return + } + + h.mailLinkCache.Delete(k) + + // reset password database call + if h.DbTx(rw, func(tx *database.Queries) error { + return tx.ChangePassword(req.Context(), userSub, pw) + }) { + return + } + + http.Error(rw, "Reset password successfully, you can login now.", http.StatusOK) +} + +func (h *httpServer) MailDelete(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + code := params.ByName("code") + + k := mailLinkKey{mailLinkDelete, code} + userSub, ok := h.mailLinkCache.Get(k) + if !ok { + http.Error(rw, "Invalid email delete code", http.StatusBadRequest) + return + } + var userInfo database.User + if h.DbTx(rw, func(tx *database.Queries) (err error) { + userInfo, err = tx.GetUser(req.Context(), userSub) + if err != nil { + return + } + return tx.FlagUserAsDeleted(req.Context(), userSub) + }) { + return + } + + h.mailLinkCache.Delete(k) + + // parse email for headers + address, err := mail.ParseAddress(userInfo.Email) + if err != nil { + http.Error(rw, "500 Internal Server Error: Failed to parse user email address", http.StatusInternalServerError) + return + } + + err = h.conf.Mail.SendEmailTemplate("mail-account-delete", "Account Deletion", userInfo.Name, address, nil) + if err != nil { + http.Error(rw, "Failed to send confirmation email.", http.StatusInternalServerError) + return + } + + http.Error(rw, "You will receive an email shortly to verify this action, you may close this tab.", http.StatusOK) +} diff --git a/server/manage-apps.go b/server/manage-apps.go index d95b752..404d46a 100644 --- a/server/manage-apps.go +++ b/server/manage-apps.go @@ -12,11 +12,17 @@ import ( "strconv" ) +func SetupManageApps(r *httprouter.Router, hs *httpServer) { + r.GET("/manage/apps", hs.RequireAuthentication(hs.ManageAppsGet)) + r.GET("/manage/apps/create", hs.RequireAuthentication(hs.ManageAppsCreateGet)) + r.POST("/manage/apps", hs.RequireAuthentication(hs.ManageAppsPost)) +} + func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { q := req.URL.Query() offset, _ := strconv.Atoi(q.Get("offset")) - var roles string + var roles []string var appList []database.GetAppListRow if h.DbTx(rw, func(tx *database.Queries) (err error) { roles, err = tx.GetUserRoles(req.Context(), auth.Subject) @@ -24,9 +30,9 @@ func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ return } appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{ - Owner: auth.Subject, - Column2: HasRole(roles, role.LavenderAdmin), - Offset: int64(offset), + OwnerSubject: auth.Subject, + Column2: HasRole(roles, role.LavenderAdmin), + Offset: int64(offset), }) return }) { @@ -61,7 +67,7 @@ func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ } func (h *httpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { - var roles string + var roles []string if h.DbTx(rw, func(tx *database.Queries) (err error) { roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return @@ -96,7 +102,7 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ active := req.Form.Has("active") if sso || hasPerms { - var roles string + var roles []string if h.DbTx(rw, func(tx *database.Queries) (err error) { roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return @@ -121,15 +127,15 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ return err } return tx.InsertClientApp(req.Context(), database.InsertClientAppParams{ - Subject: uuid.NewString(), - Name: name, - Secret: secret, - Domain: domain, - Owner: auth.Subject, - Perms: perms, - Public: public, - Sso: sso, - Active: active, + Subject: uuid.NewString(), + Name: name, + Secret: secret, + Domain: domain, + OwnerSubject: auth.Subject, + Perms: perms, + Public: public, + Sso: sso, + Active: active, }) }) { return @@ -137,15 +143,15 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ case "edit": if h.DbTx(rw, func(tx *database.Queries) error { return tx.UpdateClientApp(req.Context(), database.UpdateClientAppParams{ - Name: name, - Domain: domain, - Column3: hasPerms, - Perms: perms, - Public: public, - Sso: sso, - Active: active, - Subject: req.FormValue("subject"), - Owner: auth.Subject, + Name: name, + Domain: domain, + Column3: hasPerms, + Perms: perms, + Public: public, + Sso: sso, + Active: active, + Subject: req.FormValue("subject"), + OwnerSubject: auth.Subject, }) }) { return @@ -164,9 +170,9 @@ func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ return err } err = tx.ResetClientAppSecret(req.Context(), database.ResetClientAppSecretParams{ - Secret: secret, - Subject: sub, - Owner: auth.Subject, + Secret: secret, + Subject: sub, + OwnerSubject: auth.Subject, }) return err }) { diff --git a/server/manage-users.go b/server/manage-users.go index b3f3c8f..bf4fa8d 100644 --- a/server/manage-users.go +++ b/server/manage-users.go @@ -5,16 +5,22 @@ import ( "github.com/1f349/lavender/pages" "github.com/1f349/lavender/role" "github.com/julienschmidt/httprouter" + "golang.org/x/sync/errgroup" "net/http" "net/url" "strconv" ) +func SetupManageUsers(r *httprouter.Router, hs *httpServer) { + r.GET("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersGet)) + r.POST("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersPost)) +} + func (h *httpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { q := req.URL.Query() offset, _ := strconv.Atoi(q.Get("offset")) - var roles string + var roles []string var userList []database.GetUserListRow if h.DbTx(rw, func(tx *database.Queries) (err error) { roles, err = tx.GetUserRoles(req.Context(), auth.Subject) @@ -64,7 +70,7 @@ func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, return } - var roles string + var roles []string if h.DbTx(rw, func(tx *database.Queries) (err error) { roles, err = tx.GetUserRoles(req.Context(), auth.Subject) return @@ -78,17 +84,37 @@ func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, offset := req.Form.Get("offset") action := req.Form.Get("action") - newRoles := req.Form.Get("roles") + newRoles := req.Form["roles"] active := req.Form.Has("active") switch action { case "edit": if h.DbTx(rw, func(tx *database.Queries) error { sub := req.Form.Get("subject") - return tx.UpdateUser(req.Context(), database.UpdateUserParams{ - Active: active, - Roles: newRoles, - Subject: sub, + return tx.UseTx(req.Context(), func(tx *database.Queries) (err error) { + err = tx.ChangeUserActive(req.Context(), database.ChangeUserActiveParams{Column1: active, Subject: sub}) + if err != nil { + return err + } + err = tx.RemoveUserRoles(req.Context(), sub) + if err != nil { + return err + } + errGrp := new(errgroup.Group) + errGrp.SetLimit(3) + for _, roleName := range newRoles { + errGrp.Go(func() error { + roleId, err := strconv.ParseInt(roleName, 10, 64) + if err != nil { + return err + } + return tx.AddUserRole(req.Context(), database.AddUserRoleParams{ + RoleID: roleId, + Subject: sub, + }) + }) + } + return errGrp.Wait() }) }) { return diff --git a/server/oauth.go b/server/oauth.go index 15ba6b7..ef11372 100644 --- a/server/oauth.go +++ b/server/oauth.go @@ -1,15 +1,143 @@ package server import ( + "encoding/json" + clientStore "github.com/1f349/lavender/client-store" + "github.com/1f349/lavender/database" "github.com/1f349/lavender/logger" "github.com/1f349/lavender/pages" "github.com/1f349/lavender/scope" + "github.com/1f349/lavender/utils" + "github.com/1f349/mjwt" + "github.com/go-oauth2/oauth2/v4/generates" + "github.com/go-oauth2/oauth2/v4/manage" + "github.com/go-oauth2/oauth2/v4/server" + "github.com/go-oauth2/oauth2/v4/store" "github.com/julienschmidt/httprouter" "net/http" "net/url" "strings" + "time" ) +func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *database.Queries) { + oauthManager := manage.NewManager() + oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) + oauthManager.MustTokenStorage(store.NewMemoryTokenStore()) + oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(key, db)) + oauthManager.MapClientStorage(clientStore.New(db)) + + oauthSrv := server.NewDefaultServer(oauthManager) + oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) { + cId, cSecret, err := server.ClientBasicHandler(req) + if cId == "" && cSecret == "" { + cId, cSecret, err = server.ClientFormHandler(req) + } + if err != nil { + return "", "", err + } + return cId, cSecret, nil + }) + oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization) + oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (string, error) { + var form url.Values + if req.Method == http.MethodPost { + form = req.PostForm + } else { + form = req.URL.Query() + } + a := form.Get("scope") + if !scope.ScopesExist(a) { + return "", errInvalidScope + } + return a, nil + }) + addIdTokenSupport(oauthSrv, db, key) + + r.GET("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) + r.POST("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) + r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if err := oauthSrv.HandleTokenRequest(rw, req); err != nil { + http.Error(rw, "Failed to handle token request", http.StatusInternalServerError) + } + }) +} + +func (h *httpServer) userInfoRequest(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Access-Control-Allow-Credentials", "true") + rw.Header().Set("Access-Control-Allow-Headers", "Authorization,Content-Type") + rw.Header().Set("Access-Control-Allow-Origin", strings.TrimSuffix(req.Referer(), "/")) + rw.Header().Set("Access-Control-Allow-Methods", "GET") + if req.Method == http.MethodOptions { + return + } + + token, err := h.oauthSrv.ValidationBearerToken(req) + if err != nil { + http.Error(rw, "403 Forbidden", http.StatusForbidden) + return + } + userId := token.GetUserID() + + sso := h.manager.FindServiceFromLogin(userId) + if sso == nil { + http.Error(rw, "Invalid user", http.StatusBadRequest) + return + } + + var user database.User + if h.DbTx(rw, func(tx *database.Queries) (err error) { + user, err = tx.GetUser(req.Context(), userId) + return + }) { + return + } + + claims := ParseClaims(token.GetScope()) + if !claims["openid"] { + http.Error(rw, "Invalid scope", http.StatusBadRequest) + return + } + + m := make(map[string]any) + + if claims["name"] { + m["name"] = user.Name + } + if claims["username"] { + m["preferred_username"] = user.Login + m["login"] = user.Login + } + if claims["profile"] { + m["profile"] = user.ProfileUrl + m["picture"] = user.Picture + m["website"] = user.Website + } + if claims["email"] { + m["email"] = user.Email + m["email_verified"] = user.EmailVerified + } + if claims["birthdate"] && user.Birthdate.Valid { + m["birthdate"] = user.Birthdate.Date + } + if claims["age"] && user.Birthdate.Valid { + m["age"] = utils.Age(user.Birthdate.Date.ToTime()) + } + if claims["zoneinfo"] { + m["zoneinfo"] = user.Zone + } + if claims["locale"] { + m["locale"] = user.Locale + } + + m["sub"] = userId + m["aud"] = token.GetClientID() + m["updated_at"] = time.Now().Unix() + + _ = json.NewEncoder(rw).Encode(m) +} + func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { // function is only called with GET or POST method isPost := req.Method == http.MethodPost @@ -95,7 +223,7 @@ func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request "ServiceName": h.conf.ServiceName, "AppName": appName, "AppDomain": appDomain, - "DisplayName": auth.DisplayName, + "DisplayName": auth.UserInfo.GetStringOrEmpty("name"), "WantsList": scope.FancyScopeList(scopeList), "ResponseType": form.Get("response_type"), "ResponseMode": form.Get("response_mode"), diff --git a/server/otp.go b/server/otp.go new file mode 100644 index 0000000..0a7e799 --- /dev/null +++ b/server/otp.go @@ -0,0 +1,196 @@ +package server + +import ( + "bytes" + "context" + "encoding/base64" + "github.com/1f349/lavender/database" + "github.com/1f349/lavender/pages" + "github.com/julienschmidt/httprouter" + "github.com/skip2/go-qrcode" + "github.com/xlzd/gotp" + "html/template" + "image/png" + "net/http" + "time" +) + +func (h *httpServer) loginOtpGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + if !auth.NeedOtp { + h.SafeRedirect(rw, req) + return + } + + pages.RenderPageTemplate(rw, "login-otp", map[string]any{ + "ServiceName": h.conf.ServiceName, + "Redirect": req.URL.Query().Get("redirect"), + }) +} + +func (h *httpServer) loginOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + if !auth.NeedOtp { + http.Redirect(rw, req, "/", http.StatusFound) + return + } + + otpInput := req.FormValue("code") + if h.fetchAndValidateOtp(rw, auth.Subject, otpInput) { + return + } + + auth.NeedOtp = false + + h.setLoginDataCookie2(rw, auth) + h.SafeRedirect(rw, req) +} + +func (h *httpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code string) bool { + var hasOtp bool + var otpRow database.GetOtpRow + var secret string + var digits int64 + if h.DbTx(rw, func(tx *database.Queries) (err error) { + hasOtp, err = tx.HasOtp(context.Background(), sub) + if err != nil { + return + } + if hasOtp { + otpRow, err = tx.GetOtp(context.Background(), sub) + secret = otpRow.OtpSecret + digits = otpRow.OtpDigits + } + return + }) { + return true + } + + if hasOtp { + totp := gotp.NewTOTP(secret, int(digits), 30, nil) + if !verifyTotp(totp, code) { + http.Error(rw, "400 Bad Request: Invalid OTP code", http.StatusBadRequest) + return true + } + } + + return false +} + +func (h *httpServer) editOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + if req.Method == http.MethodPost && req.FormValue("remove") == "1" { + if !req.Form.Has("code") { + // render page + pages.RenderPageTemplate(rw, "remove-otp", map[string]any{ + "ServiceName": h.conf.ServiceName, + }) + return + } + + otpInput := req.Form.Get("code") + if h.fetchAndValidateOtp(rw, auth.Subject, otpInput) { + return + } + + if h.DbTx(rw, func(tx *database.Queries) error { + return tx.DeleteOtp(req.Context(), auth.Subject) + }) { + return + } + + http.Redirect(rw, req, "/", http.StatusFound) + return + } + + var digits int + switch req.FormValue("digits") { + case "6": + digits = 6 + case "7": + digits = 7 + case "8": + digits = 8 + default: + http.Error(rw, "400 Bad Request: Invalid number of digits for OTP code", http.StatusBadRequest) + return + } + + secret := req.FormValue("secret") + if !gotp.IsSecretValid(secret) { + http.Error(rw, "400 Bad Request: Invalid secret", http.StatusBadRequest) + return + } + + if secret == "" { + // get user email + var email string + if h.DbTx(rw, func(tx *database.Queries) error { + var err error + email, err = tx.GetUserEmail(req.Context(), auth.Subject) + return err + }) { + return + } + + secret = gotp.RandomSecret(64) + if secret == "" { + http.Error(rw, "500 Internal Server Error: failed to generate OTP secret", http.StatusInternalServerError) + return + } + totp := gotp.NewTOTP(secret, digits, 30, nil) + otpUri := totp.ProvisioningUri(email, h.conf.OtpIssuer) + code, err := qrcode.New(otpUri, qrcode.Medium) + if err != nil { + http.Error(rw, "500 Internal Server Error: failed to generate QR code", http.StatusInternalServerError) + return + } + qrImg := code.Image(60 * 4) + qrBounds := qrImg.Bounds() + qrWidth := qrBounds.Dx() + + qrBuf := new(bytes.Buffer) + if png.Encode(qrBuf, qrImg) != nil { + http.Error(rw, "500 Internal Server Error: failed to generate PNG image of QR code", http.StatusInternalServerError) + return + } + + // render page + pages.RenderPageTemplate(rw, "edit-otp", map[string]any{ + "ServiceName": h.conf.ServiceName, + "OtpQr": template.URL("data:qrImg/png;base64," + base64.StdEncoding.EncodeToString(qrBuf.Bytes())), + "QrWidth": qrWidth, + "OtpUrl": otpUri, + "OtpSecret": secret, + "OtpDigits": digits, + }) + return + } + + totp := gotp.NewTOTP(secret, digits, 30, nil) + + if !verifyTotp(totp, req.FormValue("code")) { + http.Error(rw, "400 Bad Request: invalid OTP code", http.StatusBadRequest) + return + } + + if h.DbTx(rw, func(tx *database.Queries) error { + return tx.SetOtp(req.Context(), database.SetOtpParams{ + Subject: auth.Subject, + OtpSecret: secret, + OtpDigits: int64(digits), + }) + }) { + return + } + + http.Redirect(rw, req, "/", http.StatusFound) +} + +func verifyTotp(totp *gotp.TOTP, code string) bool { + t := time.Now() + if totp.VerifyTime(code, t) { + return true + } + if totp.VerifyTime(code, t.Add(-30*time.Second)) { + return true + } + return totp.VerifyTime(code, t.Add(30*time.Second)) +} diff --git a/server/roles_test.go b/server/roles_test.go index 008ef00..2925f75 100644 --- a/server/roles_test.go +++ b/server/roles_test.go @@ -7,6 +7,5 @@ import ( func TestHasRole(t *testing.T) { assert.True(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin")) - assert.False(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin")) assert.False(t, HasRole([]string{"lavender:", "test:something-else"}, "lavender:admin")) } diff --git a/server/server.go b/server/server.go index 4af27a6..7e170e6 100644 --- a/server/server.go +++ b/server/server.go @@ -3,17 +3,14 @@ package server import ( "errors" "github.com/1f349/cache" - clientStore "github.com/1f349/lavender/client-store" "github.com/1f349/lavender/conf" "github.com/1f349/lavender/database" "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/logger" "github.com/1f349/lavender/pages" - scope2 "github.com/1f349/lavender/scope" "github.com/1f349/mjwt" - "github.com/go-oauth2/oauth2/v4/generates" "github.com/go-oauth2/oauth2/v4/manage" "github.com/go-oauth2/oauth2/v4/server" - "github.com/go-oauth2/oauth2/v4/store" "github.com/julienschmidt/httprouter" "net/http" "net/url" @@ -76,44 +73,15 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, s mailLinkCache: cache.New[mailLinkKey, string](), } - oauthManager := manage.NewManager() - oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) - oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) - oauthManager.MustTokenStorage(store.NewMemoryTokenStore()) - oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(signingKey, db)) - oauthManager.MapClientStorage(clientStore.New(db)) - - oauthSrv := server.NewDefaultServer(oauthManager) - oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) { - cId, cSecret, err := server.ClientBasicHandler(req) - if cId == "" && cSecret == "" { - cId, cSecret, err = server.ClientFormHandler(req) - } - if err != nil { - return "", "", err - } - return cId, cSecret, nil - }) - oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization) - oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (scope string, err error) { - var form url.Values - if req.Method == http.MethodPost { - form = req.PostForm - } else { - form = req.URL.Query() - } - a := form.Get("scope") - if !scope2.ScopesExist(a) { - return "", errInvalidScope - } - return a, nil - }) - addIdTokenSupport(oauthSrv, db, signingKey) - - ssoManager := issuer.NewManager(config.SsoServices) + var err error + hs.manager, err = issuer.NewManager(config.SsoServices) + if err != nil { + logger.Logger.Fatal("Failed to load SSO services", "err", err) + } SetupOpenId(r, config.BaseUrl, signingKey) - r.POST("/logout", hs.RequireAuthentication(fu)) + r.GET("/", hs.OptionalAuthentication(false, hs.Home)) + r.POST("/logout", hs.RequireAuthentication(hs.logoutPost)) // theme styles r.GET("/assets/*filepath", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { @@ -126,8 +94,16 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, s http.ServeContent(rw, req, path.Base(name), contentCache, out) }) - SetupManageApps(r) - SetupManageUsers(r) + // login steps + r.GET("/login", hs.OptionalAuthentication(false, hs.loginGet)) + r.POST("/login", hs.OptionalAuthentication(false, hs.loginPost)) + r.GET("/login/otp", hs.OptionalAuthentication(true, hs.loginOtpGet)) + r.POST("/login/otp", hs.OptionalAuthentication(true, hs.loginOtpPost)) + r.GET("/callback", hs.OptionalAuthentication(false, hs.loginCallback)) + + SetupManageApps(r, hs) + SetupManageUsers(r, hs) + SetupOAuth2(r, hs, signingKey, db) } func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) { diff --git a/sqlc.yaml b/sqlc.yaml index 2716e86..2b22daf 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -21,3 +21,11 @@ sql: go_type: "github.com/1f349/lavender/database/types.UserZone" - column: "users.locale" go_type: "github.com/1f349/lavender/database/types.UserLocale" + - column: "users.auth_type" + go_type: "github.com/1f349/lavender/database/types.AuthType" + - column: "users.access_token" + go_type: "database/sql.NullString" + - column: "users.refresh_token" + go_type: "database/sql.NullString" + - column: "users.token_expiry" + go_type: "database/sql.NullTime" diff --git a/utils/age.go b/utils/age.go new file mode 100644 index 0000000..eaf1b91 --- /dev/null +++ b/utils/age.go @@ -0,0 +1,28 @@ +package utils + +import ( + "time" +) + +var ageTimeNow = time.Now + +func Age(t time.Time) int { + n := ageTimeNow() + + // the birthday is in the future so the age is 0 + if n.Before(t) { + return 0 + } + + // the year difference + dy := n.Year() - t.Year() + + // the birthday in the current year + tCurrent := t.AddDate(dy, 0, 0) + + // minus 1 if the birthday has not yet occurred in the current year + if tCurrent.Before(n) { + dy -= 1 + } + return dy +} diff --git a/utils/age_test.go b/utils/age_test.go new file mode 100644 index 0000000..0972227 --- /dev/null +++ b/utils/age_test.go @@ -0,0 +1,30 @@ +package utils + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestAge(t *testing.T) { + lGmt := time.FixedZone("GMT", 0) + lBst := time.FixedZone("BST", 60*60) + + tPast := time.Date(1939, time.January, 5, 0, 0, 0, 0, lGmt) + tPastDst := time.Date(2001, time.January, 5, 1, 0, 0, 0, lBst) + tCur := time.Date(2005, time.January, 5, 0, 30, 0, 0, lGmt) + tCurDst := time.Date(2005, time.January, 5, 0, 30, 0, 0, lBst) + tFut := time.Date(2008, time.January, 5, 0, 0, 0, 0, time.UTC) + + ageTimeNow = func() time.Time { return tCur } + assert.Equal(t, 65, Age(tPast)) + assert.Equal(t, 3, Age(tPastDst)) + assert.Equal(t, 0, Age(tFut)) + + ageTimeNow = func() time.Time { return tCurDst } + assert.Equal(t, 66, Age(tPast)) + assert.Equal(t, 4, Age(tPastDst)) + fmt.Println(tPastDst.AddDate(4, 0, 0).UTC(), tCur.UTC()) + assert.Equal(t, 0, Age(tFut)) +}