From 7064afd55e37c3fe3815a2298d1adfd7ff4cce0b Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Fri, 13 Sep 2024 15:31:40 +0100 Subject: [PATCH] A load more changes --- auth/auth.go | 11 + auth/login.go | 1 + auth/oauth.go | 1 + {server => auth}/userinfofields.go | 2 +- cmd/lavender/serve.go | 4 +- conf/conf.go | 16 +- .../migrations/20240820202502_init.up.sql | 43 ++-- database/password-wrapper.go | 17 +- database/queries/users.sql | 6 +- database/types/authtype.go | 16 ++ database/users.sql.go | 10 +- issuer/sso.go | 5 +- role/role.go | 5 + server/auth.go | 40 +++- server/db.go | 4 +- server/home.go | 5 +- server/jwt.go | 9 +- server/login.go | 31 ++- server/manage-apps.go | 11 +- server/manage-users.go | 7 +- server/oauth.go | 4 +- server/openid.go | 38 +++ server/roles_test.go | 6 +- server/server.go | 225 ++++-------------- 24 files changed, 245 insertions(+), 272 deletions(-) create mode 100644 auth/auth.go create mode 100644 auth/login.go create mode 100644 auth/oauth.go rename {server => auth}/userinfofields.go (96%) create mode 100644 database/types/authtype.go create mode 100644 role/role.go create mode 100644 server/openid.go diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..413e728 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,11 @@ +package auth + +import "github.com/1f349/lavender/database" + +type LoginProvider interface { + AttemptLogin(username, password string) (database.User, error) +} + +type OAuthProvider interface { + AttemptLogin(username string) (database.User, error) +} diff --git a/auth/login.go b/auth/login.go new file mode 100644 index 0000000..8832b06 --- /dev/null +++ b/auth/login.go @@ -0,0 +1 @@ +package auth diff --git a/auth/oauth.go b/auth/oauth.go new file mode 100644 index 0000000..8832b06 --- /dev/null +++ b/auth/oauth.go @@ -0,0 +1 @@ +package auth diff --git a/server/userinfofields.go b/auth/userinfofields.go similarity index 96% rename from server/userinfofields.go rename to auth/userinfofields.go index 5b28e7c..7f2093c 100644 --- a/server/userinfofields.go +++ b/auth/userinfofields.go @@ -1,4 +1,4 @@ -package server +package auth type UserInfoFields map[string]any diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index ec7714c..ebdb20a 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -13,6 +13,7 @@ import ( "github.com/cloudflare/tableflip" "github.com/golang-jwt/jwt/v4" "github.com/google/subcommands" + "github.com/julienschmidt/httprouter" _ "github.com/mattn/go-sqlite3" "github.com/spf13/afero" "gopkg.in/yaml.v3" @@ -122,7 +123,8 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) logger.Logger.Fatal("Listen failed", "err", err) } - mux := server.NewHttpServer(config, db, signingKey) + mux := httprouter.New() + server.SetupRouter(mux, config, db, signingKey) srv := &http.Server{ Handler: mux, ReadTimeout: time.Minute, diff --git a/conf/conf.go b/conf/conf.go index 28d17a3..fd8f314 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -6,12 +6,12 @@ import ( ) type Conf struct { - Listen string `yaml:"listen"` - BaseUrl string `yaml:"baseUrl"` - ServiceName string `yaml:"serviceName"` - Issuer string `yaml:"issuer"` - Kid string `yaml:"kid"` - Namespace string `yaml:"namespace"` - Mail mail.Mail `yaml:"mail"` - SsoServices map[string]issuer.SsoConfig `yaml:"ssoServices"` + Listen string `yaml:"listen"` + BaseUrl string `yaml:"baseUrl"` + ServiceName string `yaml:"serviceName"` + Issuer string `yaml:"issuer"` + Kid string `yaml:"kid"` + Namespace string `yaml:"namespace"` + Mail mail.Mail `yaml:"mail"` + SsoServices []issuer.SsoConfig `yaml:"ssoServices"` } diff --git a/database/migrations/20240820202502_init.up.sql b/database/migrations/20240820202502_init.up.sql index 27795f0..d78164e 100644 --- a/database/migrations/20240820202502_init.up.sql +++ b/database/migrations/20240820202502_init.up.sql @@ -1,32 +1,33 @@ CREATE TABLE users ( - id INTEGER NOT NULL UNIQUE PRIMARY KEY AUTOINCREMENT, - subject TEXT NOT NULL UNIQUE, - password TEXT NOT NULL, + id INTEGER NOT NULL UNIQUE PRIMARY KEY AUTOINCREMENT, + subject TEXT NOT NULL UNIQUE, + password TEXT NOT NULL, - email TEXT NOT NULL, - email_verified BOOLEAN NOT NULL DEFAULT 0, + change_password BOOLEAN NOT NULL, - updated_at DATETIME NOT NULL, - registered DATETIME NOT NULL, - active BOOLEAN NOT NULL DEFAULT 1 + email TEXT NOT NULL, + email_verified BOOLEAN NOT NULL, + + updated_at DATETIME NOT NULL, + registered DATETIME NOT NULL, + active BOOLEAN NOT NULL DEFAULT 1, + + name TEXT NOT NULL, + picture TEXT NOT NULL DEFAULT '', + website TEXT NOT NULL DEFAULT '', + pronouns TEXT NOT NULL DEFAULT 'they/them', + birthdate DATE NULL DEFAULT NULL, + zone TEXT NOT NULL DEFAULT 'UTC', + locale TEXT NOT NULL DEFAULT 'en-US', + + auth_type INTEGER NOT NULL, + auth_namespace TEXT NOT NULL, + auth_user TEXT NOT NULL ); CREATE INDEX users_subject ON users (subject); -CREATE TABLE profiles -( - subject TEXT NOT NULL UNIQUE PRIMARY KEY, - name TEXT NOT NULL, - picture TEXT NOT NULL DEFAULT '', - website TEXT NOT NULL DEFAULT '', - pronouns TEXT NOT NULL DEFAULT 'they/them', - birthdate DATE NULL, - zone TEXT NOT NULL DEFAULT 'UTC', - locale TEXT NOT NULL DEFAULT 'en-US', - updated_at DATETIME NOT NULL -); - CREATE TABLE roles ( id INTEGER NOT NULL UNIQUE PRIMARY KEY AUTOINCREMENT, diff --git a/database/password-wrapper.go b/database/password-wrapper.go index e7a0d48..07f94ee 100644 --- a/database/password-wrapper.go +++ b/database/password-wrapper.go @@ -2,20 +2,19 @@ package database import ( "context" - "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/password" "github.com/google/uuid" "time" ) type AddUserParams struct { - Name string `json:"name"` - Username string `json:"username"` - Password string `json:"password"` - Email string `json:"email"` - Role types.UserRole `json:"role"` - UpdatedAt time.Time `json:"updated_at"` - Active bool `json:"active"` + Name string `json:"name"` + Subject string `json:"subject"` + Password string `json:"password"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` } func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) { @@ -28,7 +27,7 @@ func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error Subject: uuid.NewString(), Password: pwHash, Email: arg.Email, - EmailVerified: false, + EmailVerified: arg.EmailVerified, UpdatedAt: n, Registered: n, Active: true, diff --git a/database/queries/users.sql b/database/queries/users.sql index d41fcce..1a916aa 100644 --- a/database/queries/users.sql +++ b/database/queries/users.sql @@ -6,6 +6,10 @@ FROM users; INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) VALUES (?, ?, ?, ?, ?, ?, ?); +-- name: addOAuthUser :exec +INSERT INTO users (subject, password, email, email_verified, updated_at, registered, active) +VALUES (?, ?, ?, ?, ?, ?, ?); + -- name: checkLogin :one SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject) == 1 AS has_otp, email, email_verified FROM users @@ -25,7 +29,7 @@ FROM users_roles INNER JOIN users u on u.id = users_roles.user_id WHERE u.subject = ?; --- name: UserHasRole :one +-- name: UserHasRole :exec SELECT 1 FROM roles INNER JOIN users_roles on users_roles.user_id = roles.id diff --git a/database/types/authtype.go b/database/types/authtype.go new file mode 100644 index 0000000..903f717 --- /dev/null +++ b/database/types/authtype.go @@ -0,0 +1,16 @@ +package types + +type AuthType byte + +const ( + AuthTypeBase AuthType = iota + AuthTypeOauth2 +) + +var authTypeNames = map[AuthType]string{ + AuthTypeOauth2: "OAuth2", +} + +func (t AuthType) String() string { + return authTypeNames[t] +} diff --git a/database/users.sql.go b/database/users.sql.go index c814a26..04c1c02 100644 --- a/database/users.sql.go +++ b/database/users.sql.go @@ -78,7 +78,7 @@ func (q *Queries) HasUser(ctx context.Context) (bool, error) { return hasuser, err } -const userHasRole = `-- name: UserHasRole :one +const userHasRole = `-- name: UserHasRole :exec SELECT 1 FROM roles INNER JOIN users_roles on users_roles.user_id = roles.id @@ -92,11 +92,9 @@ type UserHasRoleParams struct { Subject string `json:"subject"` } -func (q *Queries) UserHasRole(ctx context.Context, arg UserHasRoleParams) (int64, error) { - row := q.db.QueryRowContext(ctx, userHasRole, arg.Role, arg.Subject) - var column_1 int64 - err := row.Scan(&column_1) - return column_1, err +func (q *Queries) UserHasRole(ctx context.Context, arg UserHasRoleParams) error { + _, err := q.db.ExecContext(ctx, userHasRole, arg.Role, arg.Subject) + return err } const addUser = `-- name: addUser :exec diff --git a/issuer/sso.go b/issuer/sso.go index ee81aec..59ddf40 100644 --- a/issuer/sso.go +++ b/issuer/sso.go @@ -17,8 +17,9 @@ var httpGet = http.Get // SsoConfig is the base URL for an OAUTH/OPENID/SSO login service // The path `/.well-known/openid-configuration` should be available type SsoConfig struct { - Addr utils.JsonUrl `json:"addr"` // https://login.example.com - Client SsoConfigClient `json:"client"` + Addr utils.JsonUrl `json:"addr"` // https://login.example.com + Namespace string `json:"namespace"` // example.com + Client SsoConfigClient `json:"client"` } type SsoConfigClient struct { diff --git a/role/role.go b/role/role.go new file mode 100644 index 0000000..6a6ee2b --- /dev/null +++ b/role/role.go @@ -0,0 +1,5 @@ +package role + +const prefix = "lavender:" + +const LavenderAdmin = prefix + "admin" diff --git a/server/auth.go b/server/auth.go index 2e7d741..60e222f 100644 --- a/server/auth.go +++ b/server/auth.go @@ -1,8 +1,11 @@ package server import ( + "database/sql" "errors" + "github.com/1f349/lavender/auth" "github.com/1f349/lavender/database" + "github.com/1f349/lavender/role" "github.com/julienschmidt/httprouter" "net/http" "net/url" @@ -12,25 +15,42 @@ import ( type UserHandler func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) type UserAuth struct { - Subject string - DisplayName string - UserInfo UserInfoFields + Subject string + NeedOtp bool + UserInfo auth.UserInfoFields } func (u UserAuth) IsGuest() bool { return u.Subject == "" } +func (u UserAuth) NextFlowUrl(origin *url.URL) *url.URL { + if u.NeedOtp { + return PrepareRedirectUrl("/login/otp", origin) + } + return nil +} + var ErrAuthHttpError = errors.New("auth http error") -func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle { +func (h *httpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle { return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { - var roles []string + var hasRole bool if h.DbTx(rw, func(tx *database.Queries) (err error) { - roles, err = tx.GetUserRoles(req.Context(), auth.Subject) + err = tx.UserHasRole(req.Context(), database.UserHasRoleParams{ + Role: role.LavenderAdmin, + Subject: auth.Subject, + }) + switch { + case err == nil: + hasRole = true + case errors.Is(err, sql.ErrNoRows): + hasRole = false + err = nil + } return }) { return } - if !HasRole(roles, "lavender:admin") { + if !hasRole { http.Error(rw, "403 Forbidden", http.StatusForbidden) return } @@ -38,7 +58,7 @@ func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Han }) } -func (h *HttpServer) RequireAuthentication(next UserHandler) httprouter.Handle { +func (h *httpServer) RequireAuthentication(next UserHandler) httprouter.Handle { return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { if auth.IsGuest() { redirectUrl := PrepareRedirectUrl("/login", req.URL) @@ -49,7 +69,7 @@ func (h *HttpServer) RequireAuthentication(next UserHandler) httprouter.Handle { }) } -func (h *HttpServer) OptionalAuthentication(next UserHandler) httprouter.Handle { +func (h *httpServer) OptionalAuthentication(next UserHandler) httprouter.Handle { return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { authUser, err := h.internalAuthenticationHandler(rw, req) if err != nil { @@ -62,7 +82,7 @@ func (h *HttpServer) OptionalAuthentication(next UserHandler) httprouter.Handle } } -func (h *HttpServer) internalAuthenticationHandler(rw http.ResponseWriter, req *http.Request) (UserAuth, error) { +func (h *httpServer) internalAuthenticationHandler(rw http.ResponseWriter, req *http.Request) (UserAuth, error) { // Delete previous login data cookie http.SetCookie(rw, &http.Cookie{ Name: "lavender-login-data", diff --git a/server/db.go b/server/db.go index f41e106..4627152 100644 --- a/server/db.go +++ b/server/db.go @@ -12,7 +12,7 @@ var ErrDatabaseActionFailed = errors.New("database action failed") // DbTx wraps a database transaction with http error messages and a simple action // function. If the action function returns an error the transaction will be // rolled back. If there is no error then the transaction is committed. -func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Queries) error) bool { +func (h *httpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Queries) error) bool { logger.Logger.Helper() if h.DbTxError(action) != nil { http.Error(rw, "Database error", http.StatusInternalServerError) @@ -22,7 +22,7 @@ func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Queri return false } -func (h *HttpServer) DbTxError(action func(tx *database.Queries) error) error { +func (h *httpServer) DbTxError(action func(tx *database.Queries) error) error { logger.Logger.Helper() err := action(h.db) if err != nil { diff --git a/server/home.go b/server/home.go index 75571cb..cc93e98 100644 --- a/server/home.go +++ b/server/home.go @@ -3,13 +3,14 @@ package server import ( "github.com/1f349/lavender/database" "github.com/1f349/lavender/pages" + "github.com/1f349/lavender/role" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "net/http" "time" ) -func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) Home(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { rw.Header().Set("Content-Type", "text/html") lNonce := uuid.NewString() http.SetCookie(rw, &http.Cookie{ @@ -30,7 +31,7 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute var isAdmin bool h.DbTx(rw, func(tx *database.Queries) (err error) { - _, err = tx.UserHasRole(req.Context(), database.UserHasRoleParams{Role: "lavender:admin", Subject: auth.Subject}) + err = tx.UserHasRole(req.Context(), database.UserHasRoleParams{Role: role.LavenderAdmin, Subject: auth.Subject}) isAdmin = err == nil return nil }) diff --git a/server/jwt.go b/server/jwt.go index 6d23725..1b19b7c 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -4,7 +4,6 @@ import ( "context" "crypto/sha256" "encoding/base64" - "github.com/1f349/lavender/database" "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" "github.com/go-oauth2/oauth2/v4" @@ -15,15 +14,19 @@ import ( type JWTAccessGenerate struct { signer *mjwt.Issuer - db *database.Queries + db mjwtGetUserRoles } -func NewJWTAccessGenerate(signer *mjwt.Issuer, db *database.Queries) *JWTAccessGenerate { +func NewMJWTAccessGenerate(signer *mjwt.Issuer, db mjwtGetUserRoles) *JWTAccessGenerate { return &JWTAccessGenerate{signer, db} } var _ oauth2.AccessGenerate = &JWTAccessGenerate{} +type mjwtGetUserRoles interface { + GetUserRoles(ctx context.Context, subject string) ([]string, error) +} + func (j *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (access, refresh string, err error) { roles, err := j.db.GetUserRoles(ctx, data.UserID) if err != nil { diff --git a/server/login.go b/server/login.go index fb57bb1..323a03c 100644 --- a/server/login.go +++ b/server/login.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + auth2 "github.com/1f349/lavender/auth" "github.com/1f349/lavender/database" "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/pages" @@ -21,7 +22,7 @@ import ( "time" ) -func (h *HttpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { if !auth.IsGuest() { h.SafeRedirect(rw, req) return @@ -42,7 +43,7 @@ func (h *HttpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr }) } -func (h *HttpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { if !auth.IsGuest() { h.SafeRedirect(rw, req) return @@ -95,7 +96,7 @@ func (h *HttpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http http.Redirect(rw, req, nextUrl, http.StatusFound) } -func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth UserAuth) { +func (h *httpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth UserAuth) { flowState, ok := h.flowState.Get(req.FormValue("state")) if !ok { http.Error(rw, "Invalid flow state", http.StatusBadRequest) @@ -123,7 +124,7 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ h.SafeRedirect(rw, req) } -func (h *HttpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (UserAuth, error) { +func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (UserAuth, error) { sessionData, err := h.fetchUserInfo(sso, token) if err != nil || sessionData.Subject == "" { return UserAuth{}, fmt.Errorf("failed to fetch user info") @@ -138,6 +139,16 @@ func (h *HttpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK if errors.Is(err, sql.ErrNoRows) { uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") + id, err := tx.AddUser(req.Context(), database.AddUserParams{ + Name: "", + Subject: sessionData.Subject, + Password: "", + Email: uEmail, + EmailVerified: uEmailVerified, + UpdatedAt: time.Now(), + Active: true, + }) + return err return tx.AddUser(req.Context(), database.AddUserParams{ Subject: sessionData.Subject, Email: uEmail, @@ -180,7 +191,7 @@ const twelveHours = 12 * time.Hour const oneWeek = 7 * 24 * time.Hour type lavenderLoginAccess struct { - UserInfo UserInfoFields `json:"user_info"` + UserInfo auth2.UserInfoFields `json:"user_info"` auth.AccessTokenClaims } @@ -197,7 +208,7 @@ func (l lavenderLoginRefresh) Valid() error { return l.RefreshTokenClaims.Valid( func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" } -func (h *HttpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool { +func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData UserAuth, loginName string) bool { ps := auth.NewPermStorage() accId := uuid.NewString() gen, err := h.signingKey.GenerateJwt(authData.Subject, accId, jwt.ClaimStrings{h.conf.BaseUrl}, twelveHours, lavenderLoginAccess{ @@ -248,7 +259,7 @@ func readJwtCookie[T mjwt.Claims](req *http.Request, cookieName string, signingK return b, nil } -func (h *HttpServer) readLoginAccessCookie(rw http.ResponseWriter, req *http.Request, u *UserAuth) error { +func (h *httpServer) readLoginAccessCookie(rw http.ResponseWriter, req *http.Request, u *UserAuth) error { loginData, err := readJwtCookie[lavenderLoginAccess](req, "lavender-login-access", h.signingKey.KeyStore()) if err != nil { return h.readLoginRefreshCookie(rw, req, u) @@ -260,7 +271,7 @@ func (h *HttpServer) readLoginAccessCookie(rw http.ResponseWriter, req *http.Req return nil } -func (h *HttpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Request, userAuth *UserAuth) error { +func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Request, userAuth *UserAuth) error { refreshData, err := readJwtCookie[lavenderLoginRefresh](req, "lavender-login-refresh", h.signingKey.KeyStore()) if err != nil { return err @@ -298,14 +309,14 @@ func (h *HttpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re return nil } -func (h *HttpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (UserAuth, error) { +func (h *httpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (UserAuth, error) { res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint) if err != nil || res.StatusCode != http.StatusOK { return UserAuth{}, fmt.Errorf("request failed") } defer res.Body.Close() - var userInfoJson UserInfoFields + var userInfoJson auth2.UserInfoFields if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil { return UserAuth{}, err } diff --git a/server/manage-apps.go b/server/manage-apps.go index 6605567..d95b752 100644 --- a/server/manage-apps.go +++ b/server/manage-apps.go @@ -4,6 +4,7 @@ import ( "github.com/1f349/lavender/database" "github.com/1f349/lavender/pages" "github.com/1f349/lavender/password" + "github.com/1f349/lavender/role" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "net/http" @@ -11,7 +12,7 @@ import ( "strconv" ) -func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { q := req.URL.Query() offset, _ := strconv.Atoi(q.Get("offset")) @@ -24,7 +25,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ } appList, err = tx.GetAppList(req.Context(), database.GetAppListParams{ Owner: auth.Subject, - Column2: HasRole(roles, "lavender:admin"), + Column2: HasRole(roles, role.LavenderAdmin), Offset: int64(offset), }) return @@ -59,7 +60,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _ pages.RenderPageTemplate(rw, "manage-apps", m) } -func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { var roles string if h.DbTx(rw, func(tx *database.Queries) (err error) { roles, err = tx.GetUserRoles(req.Context(), auth.Subject) @@ -70,7 +71,7 @@ func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Reque m := map[string]any{ "ServiceName": h.conf.ServiceName, - "IsAdmin": HasRole(roles, "lavender:admin"), + "IsAdmin": HasRole(roles, role.LavenderAdmin), } rw.Header().Set("Content-Type", "text/html") @@ -78,7 +79,7 @@ func (h *HttpServer) ManageAppsCreateGet(rw http.ResponseWriter, req *http.Reque pages.RenderPageTemplate(rw, "manage-apps-create", m) } -func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { err := req.ParseForm() if err != nil { http.Error(rw, "400 Bad Request: Failed to parse form", http.StatusBadRequest) diff --git a/server/manage-users.go b/server/manage-users.go index 02a6f37..b3f3c8f 100644 --- a/server/manage-users.go +++ b/server/manage-users.go @@ -3,13 +3,14 @@ package server import ( "github.com/1f349/lavender/database" "github.com/1f349/lavender/pages" + "github.com/1f349/lavender/role" "github.com/julienschmidt/httprouter" "net/http" "net/url" "strconv" ) -func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { q := req.URL.Query() offset, _ := strconv.Atoi(q.Get("offset")) @@ -25,7 +26,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ }) { return } - if !HasRole(roles, "lavender:admin") { + if !HasRole(roles, role.LavenderAdmin) { http.Error(rw, "403 Forbidden", http.StatusForbidden) return } @@ -56,7 +57,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _ pages.RenderPageTemplate(rw, "manage-users", m) } -func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { err := req.ParseForm() if err != nil { http.Error(rw, "400 Bad Request: Failed to parse form", http.StatusBadRequest) diff --git a/server/oauth.go b/server/oauth.go index ed03179..15ba6b7 100644 --- a/server/oauth.go +++ b/server/oauth.go @@ -10,7 +10,7 @@ import ( "strings" ) -func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { +func (h *httpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { // function is only called with GET or POST method isPost := req.Method == http.MethodPost @@ -128,7 +128,7 @@ func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request http.Redirect(rw, req, parsedRedirect.String(), http.StatusFound) } -func (h *HttpServer) oauthUserAuthorization(rw http.ResponseWriter, req *http.Request) (string, error) { +func (h *httpServer) oauthUserAuthorization(rw http.ResponseWriter, req *http.Request) (string, error) { err := req.ParseForm() if err != nil { return "", err diff --git a/server/openid.go b/server/openid.go new file mode 100644 index 0000000..7a56b9c --- /dev/null +++ b/server/openid.go @@ -0,0 +1,38 @@ +package server + +import ( + "bytes" + "encoding/json" + "github.com/1f349/lavender/logger" + "github.com/1f349/lavender/openid" + "github.com/1f349/mjwt" + "github.com/julienschmidt/httprouter" + "net/http" +) + +func SetupOpenId(r *httprouter.Router, baseUrl string, signingKey *mjwt.Issuer) { + openIdConf := openid.GenConfig(baseUrl, []string{ + "openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale", + }, []string{ + "sub", "name", "preferred_username", "profile", "picture", "website", "email", "email_verified", "gender", "birthdate", "zoneinfo", "locale", "updated_at", + }) + openIdBytes, err := json.Marshal(openIdConf) + if err != nil { + logger.Logger.Fatal("Failed to generate OpenID configuration", "err", err) + } + + jwkSetBuffer := new(bytes.Buffer) + err = mjwt.WriteJwkSetJson(jwkSetBuffer, []*mjwt.Issuer{signingKey}) + if err != nil { + logger.Logger.Fatal("Failed to generate JWK Set", "err", err) + } + + r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(openIdBytes) + }) + r.GET("/.well-known/jwks.json", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(jwkSetBuffer.Bytes()) + }) +} diff --git a/server/roles_test.go b/server/roles_test.go index 10a3cd2..008ef00 100644 --- a/server/roles_test.go +++ b/server/roles_test.go @@ -6,7 +6,7 @@ import ( ) func TestHasRole(t *testing.T) { - assert.True(t, HasRole("lavender:admin test:something-else", "lavender:admin")) - assert.False(t, HasRole("lavender:admin,test:something-else", "lavender:admin")) - assert.False(t, HasRole("lavender: test:something-else", "lavender:admin")) + assert.True(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin")) + assert.False(t, HasRole([]string{"lavender:admin", "test:something-else"}, "lavender:admin")) + assert.False(t, HasRole([]string{"lavender:", "test:something-else"}, "lavender:admin")) } diff --git a/server/server.go b/server/server.go index e989543..4af27a6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,20 +1,16 @@ package server import ( - "bytes" - "crypto/subtle" - "encoding/json" + "errors" "github.com/1f349/cache" clientStore "github.com/1f349/lavender/client-store" "github.com/1f349/lavender/conf" "github.com/1f349/lavender/database" "github.com/1f349/lavender/issuer" - "github.com/1f349/lavender/logger" - "github.com/1f349/lavender/openid" "github.com/1f349/lavender/pages" scope2 "github.com/1f349/lavender/scope" "github.com/1f349/mjwt" - "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/generates" "github.com/go-oauth2/oauth2/v4/manage" "github.com/go-oauth2/oauth2/v4/server" "github.com/go-oauth2/oauth2/v4/store" @@ -28,7 +24,7 @@ import ( var errInvalidScope = errors.New("missing required scope") -type HttpServer struct { +type httpServer struct { r *httprouter.Router oauthSrv *server.Server oauthMgr *manage.Manager @@ -36,7 +32,12 @@ type HttpServer struct { conf conf.Conf signingKey *mjwt.Issuer manager *issuer.Manager - flowState *cache.Cache[string, flowStateData] + + // flowState contains the + flowState *cache.Cache[string, flowStateData] + + // mailLinkCache contains a mapping of verify uuids to user uuids + mailLinkCache *cache.Cache[mailLinkKey, string] } type flowStateData struct { @@ -45,52 +46,44 @@ type flowStateData struct { redirect string } -func NewHttpServer(config conf.Conf, db *database.Queries, signingKey *mjwt.Issuer) *httprouter.Router { - r := httprouter.New() +type mailLink byte + +const ( + mailLinkDelete mailLink = iota + mailLinkResetPassword + mailLinkVerifyEmail +) + +type mailLinkKey struct { + action mailLink + data string +} + +func SetupRouter(r *httprouter.Router, config conf.Conf, db *database.Queries, signingKey *mjwt.Issuer) { + // remove last slash from baseUrl + config.BaseUrl = strings.TrimRight(config.BaseUrl, "/") + contentCache := time.Now() - // remove last slash from baseUrl - { - l := len(config.BaseUrl) - if config.BaseUrl[l-1] == '/' { - config.BaseUrl = config.BaseUrl[:l-1] - } - } - - openIdConf := openid.GenConfig(config.BaseUrl, []string{"openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale"}, []string{"sub", "name", "preferred_username", "profile", "picture", "website", "email", "email_verified", "gender", "birthdate", "zoneinfo", "locale", "updated_at"}) - openIdBytes, err := json.Marshal(openIdConf) - if err != nil { - logger.Logger.Fatal("Failed to generate OpenID configuration", "err", err) - } - - jwkSetBuffer := new(bytes.Buffer) - err = mjwt.WriteJwkSetJson(jwkSetBuffer, []*mjwt.Issuer{signingKey}) - if err != nil { - logger.Logger.Fatal("Failed to generate JWK Set", "err", err) - } - - oauthManager := manage.NewDefaultManager() - oauthSrv := server.NewServer(server.NewConfig(), oauthManager) - hs := &HttpServer{ - r: httprouter.New(), - oauthSrv: oauthSrv, - oauthMgr: oauthManager, + hs := &httpServer{ + r: r, db: db, conf: config, signingKey: signingKey, - flowState: cache.New[string, flowStateData](), - } - - hs.manager, err = issuer.NewManager(config.SsoServices) - if err != nil { - logger.Logger.Fatal("Failed to reload SSO service manager", "err", err) + + flowState: cache.New[string, flowStateData](), + + mailLinkCache: cache.New[mailLinkKey, string](), } + oauthManager := manage.NewManager() + oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) oauthManager.MustTokenStorage(store.NewMemoryTokenStore()) - oauthManager.MapAccessGenerate(NewJWTAccessGenerate(hs.signingKey, db)) + oauthManager.MapAccessGenerate(NewMJWTAccessGenerate(signingKey, db)) oauthManager.MapClientStorage(clientStore.New(db)) + oauthSrv := server.NewDefaultServer(oauthManager) oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) { cId, cSecret, err := server.ClientBasicHandler(req) if cId == "" && cSecret == "" { @@ -117,47 +110,10 @@ func NewHttpServer(config conf.Conf, db *database.Queries, signingKey *mjwt.Issu }) addIdTokenSupport(oauthSrv, db, signingKey) - r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - rw.WriteHeader(http.StatusOK) - _, _ = rw.Write(openIdBytes) - }) - r.GET("/.well-known/jwks.json", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - rw.WriteHeader(http.StatusOK) - _, _ = rw.Write(jwkSetBuffer.Bytes()) - }) - r.GET("/", hs.OptionalAuthentication(hs.Home)) + ssoManager := issuer.NewManager(config.SsoServices) - // login - r.GET("/login", hs.OptionalAuthentication(hs.loginGet)) - r.POST("/login", hs.OptionalAuthentication(hs.loginPost)) - r.GET("/callback", hs.OptionalAuthentication(hs.loginCallback)) - r.POST("/logout", hs.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { - cookie, err := req.Cookie("lavender-nonce") - if err != nil { - http.Error(rw, "Missing nonce", http.StatusBadRequest) - return - } - if subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(req.PostFormValue("nonce"))) == 1 { - http.SetCookie(rw, &http.Cookie{ - Name: "lavender-login-access", - Path: "/", - MaxAge: -1, - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - http.SetCookie(rw, &http.Cookie{ - Name: "lavender-login-refresh", - Path: "/", - MaxAge: -1, - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - - http.Redirect(rw, req, "/", http.StatusFound) - return - } - http.Error(rw, "Logout failed", http.StatusInternalServerError) - })) + SetupOpenId(r, config.BaseUrl, signingKey) + r.POST("/logout", hs.RequireAuthentication(fu)) // theme styles r.GET("/assets/*filepath", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { @@ -170,108 +126,11 @@ func NewHttpServer(config conf.Conf, db *database.Queries, signingKey *mjwt.Issu http.ServeContent(rw, req, path.Base(name), contentCache, out) }) - // management pages - r.GET("/manage/apps", hs.RequireAuthentication(hs.ManageAppsGet)) - r.GET("/manage/apps/create", hs.RequireAuthentication(hs.ManageAppsCreateGet)) - r.POST("/manage/apps", hs.RequireAuthentication(hs.ManageAppsPost)) - r.GET("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersGet)) - r.POST("/manage/users", hs.RequireAdminAuthentication(hs.ManageUsersPost)) - - // oauth pages - r.GET("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) - r.POST("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) - r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if err := oauthSrv.HandleTokenRequest(rw, req); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } - }) - userInfoRequest := func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { - rw.Header().Set("Access-Control-Allow-Credentials", "true") - rw.Header().Set("Access-Control-Allow-Headers", "Authorization,Content-Type") - rw.Header().Set("Access-Control-Allow-Origin", strings.TrimSuffix(req.Referer(), "/")) - rw.Header().Set("Access-Control-Allow-Methods", "GET") - if req.Method == http.MethodOptions { - return - } - - token, err := oauthSrv.ValidationBearerToken(req) - if err != nil { - http.Error(rw, "403 Forbidden", http.StatusForbidden) - return - } - userId := token.GetUserID() - - sso := hs.manager.FindServiceFromLogin(userId) - if sso == nil { - http.Error(rw, "Invalid user", http.StatusBadRequest) - return - } - - var user database.User - if hs.DbTx(rw, func(tx *database.Queries) (err error) { - user, err = tx.GetUser(req.Context(), userId) - return - }) { - return - } - - var userInfo UserInfoFields - err = json.Unmarshal([]byte(user.Userinfo), &userInfo) - if err != nil { - http.Error(rw, "500 Internal Server Error", http.StatusInternalServerError) - return - } - - claims := ParseClaims(token.GetScope()) - if !claims["openid"] { - http.Error(rw, "Invalid scope", http.StatusBadRequest) - return - } - - m := make(map[string]any) - - if claims["name"] { - m["name"] = userInfo["name"] - } - if claims["username"] { - m["preferred_username"] = userInfo["preferred_username"] - m["login"] = userInfo["login"] - } - if claims["profile"] { - m["profile"] = userInfo["profile"] - m["picture"] = userInfo["picture"] - m["website"] = userInfo["website"] - } - if claims["email"] { - m["email"] = userInfo["email"] - m["email_verified"] = userInfo["email_verified"] - } - if claims["birthdate"] { - m["birthdate"] = userInfo["birthdate"] - } - if claims["age"] { - m["age"] = userInfo["age"] - } - if claims["zoneinfo"] { - m["zoneinfo"] = userInfo["zoneinfo"] - } - if claims["locale"] { - m["locale"] = userInfo["locale"] - } - - m["sub"] = userId - m["aud"] = token.GetClientID() - m["updated_at"] = time.Now().Unix() - - _ = json.NewEncoder(rw).Encode(m) - } - r.GET("/userinfo", userInfoRequest) - r.OPTIONS("/userinfo", userInfoRequest) - - return r + SetupManageApps(r) + SetupManageUsers(r) } -func (h *HttpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) { +func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) { redirectUrl := req.FormValue("redirect") if redirectUrl == "" { http.Redirect(rw, req, "/", http.StatusFound)