diff --git a/auth/auth-buttons.go b/auth/auth-buttons.go new file mode 100644 index 0000000..b1f7b34 --- /dev/null +++ b/auth/auth-buttons.go @@ -0,0 +1,15 @@ +package auth + +import ( + "context" + "html/template" + "net/http" +) + +type Button interface { + // ButtonName defines the text to show on the button. + ButtonName() string + + // RenderButtonTemplate returns a template for the button widget. + RenderButtonTemplate(ctx context.Context, req *http.Request) template.HTML +} diff --git a/auth/auth.go b/auth/auth.go index 4d30deb..d9e0624 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -4,8 +4,8 @@ import ( "context" "errors" "fmt" + "github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/database" - "html/template" "net/http" ) @@ -38,10 +38,10 @@ type Provider interface { Name() string // RenderTemplate returns HTML to embed in the page template. - RenderTemplate(ctx context.Context, req *http.Request, user *database.User) (template.HTML, error) + RenderTemplate(ctx authContext.TemplateContext) error // AttemptLogin processes the login request. - AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error + AttemptLogin(ctx authContext.TemplateContext) error } type UserSafeError struct { diff --git a/auth/authContext/context.go b/auth/authContext/context.go new file mode 100644 index 0000000..09ea375 --- /dev/null +++ b/auth/authContext/context.go @@ -0,0 +1,48 @@ +package authContext + +import ( + "context" + "github.com/1f349/lavender/database" + "net/http" +) + +func NewTemplateContext(req *http.Request, user *database.User) TemplateContext { + return &BaseTemplateContext{ + req: req, + user: user, + } +} + +type TemplateContext interface { + Context() context.Context + Request() *http.Request + User() *database.User + Render(data any) + Data() any +} + +type BaseTemplateContext struct { + req *http.Request + user *database.User + data any +} + +func (t *BaseTemplateContext) Context() context.Context { + return t.req.Context() +} + +func (t *BaseTemplateContext) Request() *http.Request { + return t.req +} + +func (t *BaseTemplateContext) User() *database.User { + return t.user +} + +func (t *BaseTemplateContext) Render(data any) { + t.data = data +} + +func (t *BaseTemplateContext) Data() any { + return t.data +} diff --git a/auth/providers/login.go b/auth/providers/basic.go similarity index 59% rename from auth/providers/login.go rename to auth/providers/basic.go index e21c222..65c09aa 100644 --- a/auth/providers/login.go +++ b/auth/providers/basic.go @@ -4,10 +4,9 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/1f349/lavender/auth" + "github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/database" - "html/template" "net/http" ) @@ -26,22 +25,36 @@ func (b *BasicLogin) AccessState() auth.State { return auth.StateUnauthorized } func (b *BasicLogin) Name() string { return "basic" } -func (b *BasicLogin) RenderTemplate(ctx context.Context, req *http.Request, user *database.User) (template.HTML, error) { +func (b *BasicLogin) RenderTemplate(ctx authContext.TemplateContext) error { // TODO(melon): rewrite this - return template.HTML(fmt.Sprintf("
%s
", req.FormValue("username"))), nil + req := ctx.Request() + un := req.FormValue("login") + redirect := req.FormValue("redirect") + if redirect == "" { + redirect = "/" + } + ctx.Render(struct { + UserEmail string + Redirect string + }{ + UserEmail: un, + Redirect: redirect, + }) + return nil } -func (b *BasicLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error { +func (b *BasicLogin) AttemptLogin(ctx authContext.TemplateContext) error { + req := ctx.Request() un := req.FormValue("username") pw := req.FormValue("password") if len(pw) < 8 { return auth.BasicUserSafeError(http.StatusBadRequest, "Password too short") } - login, err := b.DB.CheckLogin(ctx, un, pw) + login, err := b.DB.CheckLogin(ctx.Context(), un, pw) switch { case err == nil: - return auth.LookupUser(ctx, b.DB, login.Subject, user) + return auth.LookupUser(ctx.Context(), b.DB, login.Subject, ctx.User()) case errors.Is(err, sql.ErrNoRows): return auth.BasicUserSafeError(http.StatusForbidden, "Username or password is invalid") default: diff --git a/auth/providers/oauth.go b/auth/providers/oauth.go index 855ffd5..6375246 100644 --- a/auth/providers/oauth.go +++ b/auth/providers/oauth.go @@ -5,8 +5,10 @@ import ( "fmt" "github.com/1f349/cache" "github.com/1f349/lavender/auth" + "github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/database" "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/url" "github.com/google/uuid" "golang.org/x/oauth2" "html/template" @@ -20,12 +22,15 @@ type flowStateData struct { redirect string } -var _ auth.Provider = (*OAuthLogin)(nil) +var ( + _ auth.Provider = (*OAuthLogin)(nil) + _ auth.Button = (*OAuthLogin)(nil) +) type OAuthLogin struct { DB *database.Queries - BaseUrl string + BaseUrl *url.URL flow *cache.Cache[string, flowStateData] } @@ -34,29 +39,41 @@ func (o OAuthLogin) Init() { o.flow = cache.New[string, flowStateData]() } +func (o OAuthLogin) authUrlBase(ref string) *url.URL { + return o.BaseUrl.Resolve("oauth", o.Name(), ref) +} + func (o OAuthLogin) AccessState() auth.State { return auth.StateUnauthorized } func (o OAuthLogin) Name() string { return "oauth" } -func (o OAuthLogin) RenderTemplate(ctx context.Context, req *http.Request, user *database.User) (template.HTML, error) { - return "
OAuth Login Template
", nil +func (o OAuthLogin) RenderTemplate(ctx authContext.TemplateContext) error { + // TODO: does this need to exist? + ctx.Render(map[string]any{"Error": "no"}) + return nil } -func (o OAuthLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error { - login, ok := ctx.Value(oauthServiceLogin(0)).(*issuer.WellKnownOIDC) +func (o OAuthLogin) AttemptLogin(ctx authContext.TemplateContext) error { + rCtx := ctx.Context() + + login, ok := rCtx.Value(oauthServiceLogin(0)).(*issuer.WellKnownOIDC) if !ok { return fmt.Errorf("missing issuer wellknown") } - loginName := ctx.Value("login_full").(string) - loginUn := ctx.Value("login_username").(string) + loginName := rCtx.Value("login_full").(string) + loginUn := rCtx.Value("login_username").(string) // save state for use later state := login.Config.Namespace + ":" + uuid.NewString() - o.flow.Set(state, flowStateData{loginName, login, req.PostFormValue("redirect")}, time.Now().Add(15*time.Minute)) + o.flow.Set(state, flowStateData{ + loginName: loginName, + sso: login, + redirect: ctx.Request().PostFormValue("redirect"), + }, time.Now().Add(15*time.Minute)) // generate oauth2 config and redirect to authorize URL oa2conf := login.OAuth2Config - oa2conf.RedirectURL = o.BaseUrl + "/callback" + oa2conf.RedirectURL = o.authUrlBase("callback").String() nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn)) return auth.RedirectError{Target: nextUrl, Code: http.StatusFound} @@ -68,7 +85,7 @@ func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, inf http.Error(rw, "Invalid flow state", http.StatusBadRequest) return } - token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.BaseUrl+"/callback")) + token, err := flowState.sso.OAuth2Config.Exchange(context.Background(), req.FormValue("code"), oauth2.SetAuthURLParam("redirect_uri", o.authUrlBase("callback").String())) if err != nil { http.Error(rw, "Failed to exchange code for token", http.StatusInternalServerError) return @@ -90,6 +107,13 @@ func (o OAuthLogin) OAuthCallback(rw http.ResponseWriter, req *http.Request, inf redirect(rw, req) } +func (o OAuthLogin) ButtonName() string { return o.Name() } + +func (o OAuthLogin) RenderButtonTemplate(ctx context.Context, req *http.Request) template.HTML { + // o.authUrlBase("button") + return "
OAuth Login Template
" +} + type oauthServiceLogin int func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context { diff --git a/auth/providers/otp.go b/auth/providers/otp.go index ed1279b..cfb7f63 100644 --- a/auth/providers/otp.go +++ b/auth/providers/otp.go @@ -5,9 +5,9 @@ import ( "errors" "fmt" "github.com/1f349/lavender/auth" + "github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/database" "github.com/xlzd/gotp" - "html/template" "net/http" "time" ) @@ -30,19 +30,8 @@ func (o *OtpLogin) AccessState() auth.State { return auth.StateBasic } func (o *OtpLogin) Name() string { return "basic" } -func (o *OtpLogin) RenderTemplate(_ context.Context, _ *http.Request, user *database.User) (template.HTML, error) { - if user == nil || user.Subject == "" { - return "", fmt.Errorf("requires previous factor") - } - if user.OtpSecret == "" || !isDigitsSupported(user.OtpDigits) { - return "", fmt.Errorf("user does not support factor") - } - - // no need to provide render data - return "
OTP login template
", nil -} - -func (o *OtpLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error { +func (o *OtpLogin) RenderTemplate(ctx authContext.TemplateContext) error { + user := ctx.User() if user == nil || user.Subject == "" { return fmt.Errorf("requires previous factor") } @@ -50,7 +39,25 @@ func (o *OtpLogin) AttemptLogin(ctx context.Context, req *http.Request, user *da return fmt.Errorf("user does not support factor") } - code := req.FormValue("code") + // TODO: is this right? + ctx.Render(map[string]any{ + "Redirect": "/", + }) + + // no need to provide render data + return nil +} + +func (o *OtpLogin) AttemptLogin(ctx authContext.TemplateContext) error { + user := ctx.User() + if user == nil || user.Subject == "" { + return fmt.Errorf("requires previous factor") + } + if user.OtpSecret == "" || !isDigitsSupported(user.OtpDigits) { + return fmt.Errorf("user does not support factor") + } + + code := ctx.Request().FormValue("code") if !validateTotp(user.OtpSecret, int(user.OtpDigits), code) { return auth.BasicUserSafeError(http.StatusBadRequest, "invalid OTP code") diff --git a/auth/providers/passkey.go b/auth/providers/passkey.go index 642f5a4..7e029a4 100644 --- a/auth/providers/passkey.go +++ b/auth/providers/passkey.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "github.com/1f349/lavender/auth" - "github.com/1f349/lavender/database" + "github.com/1f349/lavender/auth/authContext" "html/template" "net/http" ) @@ -13,7 +13,10 @@ type passkeyLoginDB interface { auth.LookupUserDB } -var _ auth.Provider = (*PasskeyLogin)(nil) +var ( + _ auth.Provider = (*PasskeyLogin)(nil) + _ auth.Button = (*PasskeyLogin)(nil) +) type PasskeyLogin struct { DB passkeyLoginDB @@ -23,12 +26,13 @@ func (p *PasskeyLogin) AccessState() auth.State { return auth.StateUnauthorized func (p *PasskeyLogin) Name() string { return "passkey" } -func (p *PasskeyLogin) RenderTemplate(ctx context.Context, req *http.Request, user *database.User) (template.HTML, error) { +func (p *PasskeyLogin) RenderTemplate(ctx authContext.TemplateContext) error { + user := ctx.User() if user == nil || user.Subject == "" { - return "", fmt.Errorf("requires previous factor") + return fmt.Errorf("requires previous factor") } if user.OtpSecret == "" { - return "", fmt.Errorf("user does not support factor") + return fmt.Errorf("user does not support factor") } panic("implement me") @@ -40,7 +44,8 @@ func init() { passkeyShortcut = true } -func (p *PasskeyLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error { +func (p *PasskeyLogin) AttemptLogin(ctx authContext.TemplateContext) error { + user := ctx.User() if user.Subject == "" && !passkeyShortcut { return fmt.Errorf("requires previous factor") } @@ -48,3 +53,11 @@ func (p *PasskeyLogin) AttemptLogin(ctx context.Context, req *http.Request, user //TODO implement me panic("implement me") } + +func (p *PasskeyLogin) ButtonName() string { + return "Login with Passkey" +} + +func (p *PasskeyLogin) RenderButtonTemplate(ctx context.Context, req *http.Request) template.HTML { + return "
Passkey Button
" +} diff --git a/conf/conf.go b/conf/conf.go index 841e7ee..c7f9009 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -2,12 +2,13 @@ package conf import ( "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/url" "github.com/1f349/simplemail" ) type Conf struct { Listen string `yaml:"listen"` - BaseUrl string `yaml:"baseUrl"` + BaseUrl url.URL `yaml:"baseUrl"` ServiceName string `yaml:"serviceName"` Issuer string `yaml:"issuer"` Kid string `yaml:"kid"` diff --git a/mail/mail.go b/mail/mail.go index dca2e29..3bcde6b 100644 --- a/mail/mail.go +++ b/mail/mail.go @@ -28,7 +28,7 @@ func New(sender *simplemail.Mail, wd, name string) (*Mail, error) { err := os.Mkdir(mailDir, os.ModePerm) if err == nil || errors.Is(err, os.ErrExist) { wdFs := os.DirFS(mailDir) - o = overlapfs.OverlapFS{A: embeddedTemplates, B: wdFs} + o = overlapfs.OverlapFS{A: o, B: wdFs} } } diff --git a/openid/config.go b/openid/config.go index 99fbdf8..7c52e2b 100644 --- a/openid/config.go +++ b/openid/config.go @@ -1,8 +1,6 @@ package openid -import ( - "strings" -) +import "github.com/1f349/lavender/url" type Config struct { Issuer string `json:"issuer"` @@ -16,21 +14,17 @@ type Config struct { JwksUri string `json:"jwks_uri"` } -func GenConfig(baseUrl string, scopes, claims []string) Config { - baseUrlRaw := baseUrl - if !strings.HasSuffix(baseUrl, "/") { - baseUrl += "/" - } +func GenConfig(baseUrl *url.URL, scopes, claims []string) Config { return Config{ - Issuer: baseUrlRaw, - AuthorizationEndpoint: baseUrl + "authorize", - TokenEndpoint: baseUrl + "token", - UserInfoEndpoint: baseUrl + "userinfo", + Issuer: baseUrl.String(), + AuthorizationEndpoint: baseUrl.Resolve("authorize").String(), + TokenEndpoint: baseUrl.Resolve("token").String(), + UserInfoEndpoint: baseUrl.Resolve("userinfo").String(), ResponseTypesSupported: []string{"code"}, ScopesSupported: scopes, ClaimsSupported: claims, GrantTypesSupported: []string{"authorization_code", "refresh_token"}, - JwksUri: baseUrl + ".well-known/jwks.json", + JwksUri: baseUrl.Resolve(".well-known/jwks.json").String(), } } diff --git a/openid/config_test.go b/openid/config_test.go index d8a8158..b8afa2d 100644 --- a/openid/config_test.go +++ b/openid/config_test.go @@ -1,6 +1,7 @@ package openid import ( + "github.com/1f349/lavender/url" "github.com/stretchr/testify/assert" "testing" ) @@ -16,5 +17,5 @@ func TestGenConfig(t *testing.T) { ClaimsSupported: []string{"name", "email", "preferred_username"}, GrantTypesSupported: []string{"authorization_code", "refresh_token"}, JwksUri: "https://example.com/.well-known/jwks.json", - }, GenConfig("https://example.com", []string{"openid", "email"}, []string{"name", "email", "preferred_username"})) + }, GenConfig(url.MustParse("https://example.com"), []string{"openid", "email"}, []string{"name", "email", "preferred_username"})) } diff --git a/server/login.go b/server/login.go index 992700c..f6d0f5b 100644 --- a/server/login.go +++ b/server/login.go @@ -1,25 +1,29 @@ package server import ( + "bytes" "context" "database/sql" "encoding/json" "errors" "fmt" - auth2 "github.com/1f349/lavender/auth" + "github.com/1f349/lavender/auth" + "github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/auth/providers" "github.com/1f349/lavender/database" "github.com/1f349/lavender/database/types" "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/logger" "github.com/1f349/lavender/web" "github.com/1f349/mjwt" - "github.com/1f349/mjwt/auth" + mjwtAuth "github.com/1f349/mjwt/auth" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "github.com/mrmelon54/pronouns" "golang.org/x/oauth2" "golang.org/x/text/language" + "html/template" "net/http" "net/url" "strings" @@ -42,7 +46,7 @@ func getUserLoginName(req *http.Request) string { return originUrl.Query().Get("login_name") } -func (h *httpServer) testAuthSources(req *http.Request, user *database.User, factor auth2.State) map[string]bool { +func (h *httpServer) testAuthSources(req *http.Request, user *database.User, factor auth.State) map[string]bool { authSource := make(map[string]bool) data := make(map[string]any) for _, i := range h.authSources { @@ -50,23 +54,46 @@ func (h *httpServer) testAuthSources(req *http.Request, user *database.User, fac if i.AccessState() != factor { continue } - page, err := i.RenderTemplate(req.Context(), req, user) - _ = page + err := i.RenderTemplate(authContext.NewTemplateContext(req, user)) authSource[i.Name()] = err == nil clear(data) } return authSource } -func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth auth2.UserAuth) { - if !auth.IsGuest() { +func (h *httpServer) getAuthWithState(state auth.State) auth.Provider { + for _, i := range h.authSources { + if i.AccessState() == state { + return i + } + } + return nil +} + +func (h *httpServer) renderAuthTemplate(req *http.Request, provider auth.Provider) (template.HTML, error) { + tmpCtx := authContext.NewTemplateContext(req, new(database.User)) + + err := provider.RenderTemplate(tmpCtx) + if err != nil { + return "", err + } + + w := new(bytes.Buffer) + if web.RenderPageTemplate(w, "auth/"+provider.Name(), tmpCtx.Data()) { + return template.HTML(w.Bytes()), nil + } + return "", fmt.Errorf("failed to render auth template") +} + +func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth auth.UserAuth) { + if !userAuth.IsGuest() { h.SafeRedirect(rw, req) return } cookie, err := req.Cookie("lavender-login-name") if err == nil && cookie.Valid() == nil { - user, err := h.db.GetUser(req.Context(), auth.Subject) + user, err := h.db.GetUser(req.Context(), userAuth.Subject) var userPtr *database.User switch { case err == nil: @@ -78,30 +105,55 @@ func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr return } - fmt.Printf("%#v\n", h.testAuthSources(req, userPtr, auth2.StateBasic)) + fmt.Printf("%#v\n", h.testAuthSources(req, userPtr, auth.StateBasic)) web.RenderPageTemplate(rw, "login-memory", map[string]any{ "ServiceName": h.conf.ServiceName, "LoginName": cookie.Value, "Redirect": req.URL.Query().Get("redirect"), "Source": "start", - "Auth": h.testAuthSources(req, userPtr, auth2.StateBasic), + "Auth": h.testAuthSources(req, userPtr, auth.StateBasic), }) return } + buttonTemplates := make([]template.HTML, len(h.authButtons)) + for i := range h.authButtons { + buttonTemplates[i] = h.authButtons[i].RenderButtonTemplate(req.Context(), req) + } + + type loginError struct { + Error string `json:"error"` + } + + var renderTemplate template.HTML + + provider := h.getAuthWithState(auth.StateUnauthorized) + + // Maybe the admin has disabled some login providers but does have a button based provider available? + if provider != nil { + renderTemplate, err = h.renderAuthTemplate(req, provider) + if err != nil { + logger.Logger.Warn("No provider for login") + web.RenderPageTemplate(rw, "login-error", loginError{Error: "No available provider for login"}) + return + } + } + // render different page sources web.RenderPageTemplate(rw, "login", map[string]any{ - "ServiceName": h.conf.ServiceName, - "LoginName": "", - "Redirect": req.URL.Query().Get("redirect"), - "Source": "start", - "Auth": h.testAuthSources(req, nil, auth2.StateBasic), + "ServiceName": h.conf.ServiceName, + "LoginName": "", + "Redirect": req.URL.Query().Get("redirect"), + "Source": "start", + "Auth": h.testAuthSources(req, nil, auth.StateUnauthorized), + "AuthTemplate": renderTemplate, + "AuthButtons": buttonTemplates, }) } -func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth auth2.UserAuth) { - if !auth.IsGuest() { +func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth2 auth.UserAuth) { + if !auth2.IsGuest() { h.SafeRedirect(rw, req) return } @@ -120,7 +172,7 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http }).String(), http.StatusFound) return } - loginName := req.PostFormValue("loginname") + loginName := req.PostFormValue("email") // append local namespace if @ is missing n := strings.IndexByte(loginName, '@') @@ -156,12 +208,16 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http SameSite: http.SameSiteLaxMode, }) - var redirectError auth2.RedirectError + var redirectError auth.RedirectError + + // TODO(melon): rewrite login system here // if the login is the local server if login == issuer.MeWellKnown { // TODO(melon): work on this - err := h.authBasic.AttemptLogin(ctx, req, nil) + // TODO: rewrite + //err := h.authBasic.AttemptLogin(ctx, req, nil) + var err error switch { case errors.As(err, &redirectError): http.Redirect(rw, req, redirectError.Target, redirectError.Code) @@ -170,7 +226,9 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http return } - err := h.authOAuth.AttemptLogin(ctx, req, nil) + // TODO: rewrite + //err := h.authOAuth.AttemptLogin(ctx, req, nil) + var err error switch { case errors.As(err, &redirectError): http.Redirect(rw, req, redirectError.Target, redirectError.Code) @@ -178,14 +236,15 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http } } -func (h *httpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth auth2.UserAuth) { - h.authOAuth.OAuthCallback(rw, req, h.updateExternalUserInfo, h.setLoginDataCookie, h.SafeRedirect) +func (h *httpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ auth.UserAuth) { + // TODO: rewrite + //h.authOAuth.OAuthCallback(rw, req, h.updateExternalUserInfo, h.setLoginDataCookie, h.SafeRedirect) } -func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth2.UserAuth, error) { +func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) { sessionData, err := h.fetchUserInfo(sso, token) if err != nil || sessionData.Subject == "" { - return auth2.UserAuth{}, fmt.Errorf("failed to fetch user info") + return auth.UserAuth{}, fmt.Errorf("failed to fetch user info") } // TODO(melon): fix this to use a merging of lavender and tulip auth @@ -206,9 +265,9 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK err = h.DbTxError(func(tx *database.Queries) error { return h.updateOAuth2UserProfile(req.Context(), tx, sessionData) }) - return auth2.UserAuth{ + return auth.UserAuth{ Subject: userSubject, - Factor: auth2.StateExtended, + Factor: auth.StateExtended, UserInfo: sessionData.UserInfo, }, err case errors.Is(err, sql.ErrNoRows): @@ -216,12 +275,12 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK break default: // another error occurred - return auth2.UserAuth{}, err + return auth.UserAuth{}, err } // guard for disabled registration if !sso.Config.Registration { - return auth2.UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source") + return auth.UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source") } // TODO(melon): rework this @@ -246,7 +305,7 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK return h.updateOAuth2UserProfile(req.Context(), tx, sessionData) }) if err != nil { - return auth2.UserAuth{}, err + return auth.UserAuth{}, err } // only continues if the above tx succeeds @@ -258,20 +317,20 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK Subject: sessionData.Subject, }) }); err != nil { - return auth2.UserAuth{}, err + return auth.UserAuth{}, err } // TODO(melon): this feels bad - sessionData = auth2.UserAuth{ + sessionData = auth.UserAuth{ Subject: userSubject, - Factor: auth2.StateExtended, + Factor: auth.StateExtended, UserInfo: sessionData.UserInfo, } return sessionData, nil } -func (h *httpServer) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData auth2.UserAuth) error { +func (h *httpServer) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData auth.UserAuth) error { // all of these updates must succeed return tx.UseTx(ctx, func(tx *database.Queries) error { name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User") @@ -312,9 +371,9 @@ const twelveHours = 12 * time.Hour const oneWeek = 7 * 24 * time.Hour type lavenderLoginAccess struct { - UserInfo auth2.UserInfoFields `json:"user_info"` - Factor auth2.State `json:"factor"` - auth.AccessTokenClaims + UserInfo auth.UserInfoFields `json:"user_info"` + Factor auth.State `json:"factor"` + mjwtAuth.AccessTokenClaims } func (l lavenderLoginAccess) Valid() error { return l.AccessTokenClaims.Valid() } @@ -323,28 +382,28 @@ func (l lavenderLoginAccess) Type() string { return "lavender-login-access" } type lavenderLoginRefresh struct { Login string `json:"login"` - auth.RefreshTokenClaims + mjwtAuth.RefreshTokenClaims } func (l lavenderLoginRefresh) Valid() error { return l.RefreshTokenClaims.Valid() } func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" } -func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData auth2.UserAuth, loginName string) bool { - ps := auth.NewPermStorage() +func (h *httpServer) setLoginDataCookie(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool { + ps := mjwtAuth.NewPermStorage() accId := uuid.NewString() - gen, err := h.signingKey.GenerateJwt(authData.Subject, accId, jwt.ClaimStrings{h.conf.BaseUrl}, twelveHours, lavenderLoginAccess{ + gen, err := h.signingKey.GenerateJwt(authData.Subject, accId, jwt.ClaimStrings{h.conf.BaseUrl.String()}, twelveHours, lavenderLoginAccess{ UserInfo: authData.UserInfo, Factor: authData.Factor, - AccessTokenClaims: auth.AccessTokenClaims{Perms: ps}, + AccessTokenClaims: mjwtAuth.AccessTokenClaims{Perms: ps}, }) if err != nil { http.Error(rw, "Failed to generate cookie token", http.StatusInternalServerError) return true } - ref, err := h.signingKey.GenerateJwt(authData.Subject, uuid.NewString(), jwt.ClaimStrings{h.conf.BaseUrl}, oneWeek, lavenderLoginRefresh{ + ref, err := h.signingKey.GenerateJwt(authData.Subject, uuid.NewString(), jwt.ClaimStrings{h.conf.BaseUrl.String()}, oneWeek, lavenderLoginRefresh{ Login: loginName, - RefreshTokenClaims: auth.RefreshTokenClaims{AccessTokenId: accId}, + RefreshTokenClaims: mjwtAuth.RefreshTokenClaims{AccessTokenId: accId}, }) if err != nil { http.Error(rw, "Failed to generate cookie token", http.StatusInternalServerError) @@ -382,12 +441,12 @@ func readJwtCookie[T mjwt.Claims](req *http.Request, cookieName string, signingK return b, nil } -func (h *httpServer) readLoginAccessCookie(rw http.ResponseWriter, req *http.Request, u *auth2.UserAuth) error { +func (h *httpServer) readLoginAccessCookie(rw http.ResponseWriter, req *http.Request, u *auth.UserAuth) error { loginData, err := readJwtCookie[lavenderLoginAccess](req, "lavender-login-access", h.signingKey.KeyStore()) if err != nil { return h.readLoginRefreshCookie(rw, req, u) } - *u = auth2.UserAuth{ + *u = auth.UserAuth{ Subject: loginData.Subject, Factor: loginData.Claims.Factor, UserInfo: loginData.Claims.UserInfo, @@ -395,7 +454,7 @@ func (h *httpServer) readLoginAccessCookie(rw http.ResponseWriter, req *http.Req return nil } -func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Request, userAuth *auth2.UserAuth) error { +func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Request, userAuth *auth.UserAuth) error { refreshData, err := readJwtCookie[lavenderLoginRefresh](req, "lavender-login-refresh", h.signingKey.KeyStore()) if err != nil { return err @@ -433,28 +492,28 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re return nil } -func (h *httpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth2.UserAuth, error) { +func (h *httpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) { res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint) if err != nil || res.StatusCode != http.StatusOK { - return auth2.UserAuth{}, fmt.Errorf("request failed") + return auth.UserAuth{}, fmt.Errorf("request failed") } defer res.Body.Close() - var userInfoJson auth2.UserInfoFields + var userInfoJson auth.UserInfoFields if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil { - return auth2.UserAuth{}, err + return auth.UserAuth{}, err } subject, ok := userInfoJson.GetString("sub") if !ok { - return auth2.UserAuth{}, fmt.Errorf("invalid subject") + return auth.UserAuth{}, fmt.Errorf("invalid subject") } // TODO(melon): there is no need for this //subject += "@" + sso.Config.Namespace - return auth2.UserAuth{ + return auth.UserAuth{ Subject: subject, - Factor: auth2.StateExtended, + Factor: auth.StateExtended, UserInfo: userInfoJson, }, nil } diff --git a/server/oauth.go b/server/oauth.go index 1046b59..89c0bb2 100644 --- a/server/oauth.go +++ b/server/oauth.go @@ -25,6 +25,9 @@ import ( "time" ) +// TODO(melon): add ldap client, radius client and other login support +// TODO(melon): add ldap server, radius server support + func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *database.Queries) { oauthManager := manage.NewDefaultManager() oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) diff --git a/server/openid.go b/server/openid.go index 7a56b9c..fd61587 100644 --- a/server/openid.go +++ b/server/openid.go @@ -5,12 +5,13 @@ import ( "encoding/json" "github.com/1f349/lavender/logger" "github.com/1f349/lavender/openid" + "github.com/1f349/lavender/url" "github.com/1f349/mjwt" "github.com/julienschmidt/httprouter" "net/http" ) -func SetupOpenId(r *httprouter.Router, baseUrl string, signingKey *mjwt.Issuer) { +func SetupOpenId(r *httprouter.Router, baseUrl *url.URL, signingKey *mjwt.Issuer) { openIdConf := openid.GenConfig(baseUrl, []string{ "openid", "name", "username", "profile", "email", "birthdate", "age", "zoneinfo", "locale", }, []string{ diff --git a/server/otp.go b/server/otp.go index 25171fb..ed4dac8 100644 --- a/server/otp.go +++ b/server/otp.go @@ -26,7 +26,10 @@ func (h *httpServer) editOtpPost(rw http.ResponseWriter, req *http.Request, _ ht } otpInput := req.Form.Get("code") - err := h.authOtp.VerifyOtpCode(req.Context(), auth.Subject, otpInput) + _ = otpInput + // TODO: rewrite + //err := h.authOtp.VerifyOtpCode(req.Context(), auth.Subject, otpInput) + var err error if err != nil { http.Error(rw, "Invalid OTP code", http.StatusBadRequest) return diff --git a/server/server.go b/server/server.go index 28bb9e4..a6baef1 100644 --- a/server/server.go +++ b/server/server.go @@ -35,11 +35,8 @@ type httpServer struct { // mailLinkCache contains a mapping of verify uuids to user uuids mailLinkCache *cache.Cache[mailLinkKey, string] - authBasic *providers.BasicLogin - authOtp *providers.OtpLogin - authOAuth *providers.OAuthLogin - authSources []auth.Provider + authButtons []auth.Button } type mailLink byte @@ -56,13 +53,26 @@ type mailLinkKey struct { } func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, db *database.Queries, signingKey *mjwt.Issuer) { - // remove last slash from baseUrl - config.BaseUrl = strings.TrimRight(config.BaseUrl, "/") - + // TODO: move auth provider init to main function + // TODO: allow dynamically changing the providers based on database information authBasic := &providers.BasicLogin{DB: db} authOtp := &providers.OtpLogin{DB: db} - authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: config.BaseUrl} + authOAuth := &providers.OAuthLogin{DB: db, BaseUrl: &config.BaseUrl} authOAuth.Init() + authPasskey := &providers.PasskeyLogin{DB: db} + + authSources := []auth.Provider{ + authBasic, + authOtp, + authOAuth, + authPasskey, + } + authButtons := make([]auth.Button, 0) + for _, source := range authSources { + if button, isButton := source.(auth.Button); isButton { + authButtons = append(authButtons, button) + } + } hs := &httpServer{ r: r, @@ -73,15 +83,8 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, mailLinkCache: cache.New[mailLinkKey, string](), - authBasic: authBasic, - authOtp: authOtp, - authOAuth: authOAuth, - //authPasskey: &auth.PasskeyLogin{DB: db}, - - authSources: []auth.Provider{ - authBasic, - authOtp, - }, + authSources: authSources, + authButtons: authButtons, } var err error @@ -90,7 +93,7 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail, logger.Logger.Fatal("Failed to load SSO services", "err", err) } - SetupOpenId(r, config.BaseUrl, signingKey) + SetupOpenId(r, &config.BaseUrl, signingKey) r.GET("/", hs.OptionalAuthentication(false, hs.Home)) r.POST("/logout", hs.RequireAuthentication(hs.logoutPost)) diff --git a/tmp/main b/tmp/main new file mode 100644 index 0000000..75ae253 Binary files /dev/null and b/tmp/main differ diff --git a/url/url.go b/url/url.go new file mode 100644 index 0000000..5aa80ba --- /dev/null +++ b/url/url.go @@ -0,0 +1,40 @@ +package url + +import ( + "encoding" + "net/url" + "path" +) + +type URL struct { + url.URL +} + +func (u *URL) Resolve(paths ...string) *URL { + return &URL{URL: *u.URL.ResolveReference(&url.URL{Path: path.Join(paths...)})} +} + +func (u URL) MarshalText() (text []byte, err error) { + return []byte(u.String()), nil +} + +func (u *URL) UnmarshalText(text []byte) error { + parse, err := u.Parse(string(text)) + if err != nil { + return err + } + + u.URL = *parse + return nil +} + +var _ encoding.TextMarshaler = (*URL)(nil) +var _ encoding.TextUnmarshaler = (*URL)(nil) + +func MustParse(rawURL string) *URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return &URL{*u} +} diff --git a/web/astro.config.mjs b/web/astro.config.mjs index e2b8773..a588bca 100644 --- a/web/astro.config.mjs +++ b/web/astro.config.mjs @@ -7,9 +7,12 @@ import svelte from '@astrojs/svelte'; // https://astro.build/config export default defineConfig({ - integrations: [tailwind({ - nesting: true, - }), svelte()], + integrations: [ + tailwind({ + nesting: true, + }), + svelte({extensions: ['.svelte']}), + ], build: { format: 'file', }, diff --git a/web/src/components/auth-buttons/PasskeyButton.svelte b/web/src/components/auth-buttons/PasskeyButton.svelte new file mode 100644 index 0000000..3e04e69 --- /dev/null +++ b/web/src/components/auth-buttons/PasskeyButton.svelte @@ -0,0 +1,7 @@ + + +Sign in with Passkey diff --git a/web/src/pages/auth-buttons/oauth.astro b/web/src/pages/auth-buttons/oauth.astro new file mode 100644 index 0000000..ee55cc5 --- /dev/null +++ b/web/src/pages/auth-buttons/oauth.astro @@ -0,0 +1,5 @@ +--- +export const partial = true; +--- + +[[ .ButtonName ]] diff --git a/web/src/pages/auth-buttons/passkey.astro b/web/src/pages/auth-buttons/passkey.astro new file mode 100644 index 0000000..6da81e7 --- /dev/null +++ b/web/src/pages/auth-buttons/passkey.astro @@ -0,0 +1,7 @@ +--- +export const partial = true; + +import PasskeyButton from '../../components/auth-buttons/PasskeyButton.svelte'; +--- + + diff --git a/web/src/pages/auth/basic.astro b/web/src/pages/auth/basic.astro new file mode 100644 index 0000000..d044c5a --- /dev/null +++ b/web/src/pages/auth/basic.astro @@ -0,0 +1,16 @@ +--- +export const partial = true; +--- + +
+ +
+ + +
+
+ + +
+ +
diff --git a/web/src/pages/auth/password.astro b/web/src/pages/auth/password.astro deleted file mode 100644 index eb82814..0000000 --- a/web/src/pages/auth/password.astro +++ /dev/null @@ -1,26 +0,0 @@ ---- -export const partial = true; ---- - -
- -
- - -
-
- - -
- -
- -
-

Enter your email address below to receive an email with instructions on how to reset your password.

-

Please note this only works if your email address is already verified.

-
- - -
- -
diff --git a/web/src/pages/auth/sso.astro b/web/src/pages/auth/sso.astro new file mode 100644 index 0000000..8eff4ec --- /dev/null +++ b/web/src/pages/auth/sso.astro @@ -0,0 +1,12 @@ +--- +export const partial = true; +--- + +
+ +
+ + +
+ +
diff --git a/web/src/pages/login.astro b/web/src/pages/login.astro index ca2bf80..fea5857 100644 --- a/web/src/pages/login.astro +++ b/web/src/pages/login.astro @@ -9,6 +9,17 @@ import Layout from "../layouts/Layout.astro";

Check your inbox for a verification email

[[ end ]] [[ .AuthTemplate ]] +
+
+

Enter your email address below to receive an email with instructions on how to reset your password.

+

Please note this only works if your email address is already verified.

+
+ + +
+ +
+
[[ if gt (len .AuthButtons) 0 ]]
[[ range $authButton := .AuthButtons ]] @@ -16,9 +27,4 @@ import Layout from "../layouts/Layout.astro"; [[ end ]]
[[ end ]] - diff --git a/web/web.go b/web/web.go index 88c781d..8ddad9f 100644 --- a/web/web.go +++ b/web/web.go @@ -18,7 +18,7 @@ import ( var ( //go:embed dist - webBuild embed.FS + webDist embed.FS webCombinedDir fs.FS pageTemplates *template.Template @@ -27,6 +27,11 @@ var ( func LoadPages(wd string) error { return loadOnce.Do(func() (err error) { + webBuild, err := fs.Sub(webDist, "dist") + if err != nil { + return err + } + webCombinedDir = webBuild if wd != "" { @@ -44,16 +49,39 @@ func LoadPages(wd string) error { // TODO(melon): figure this out layer webCombinedDir = webBuild - pageTemplates, err = template.New("web").Delims("[[", "]]").Funcs(template.FuncMap{ + pageTemplates, err = findAndParseTemplates(webCombinedDir, template.FuncMap{ "emailHide": utils.EmailHide, "renderOptionTag": renderOptionTag, "renderCheckboxTag": renderCheckboxTag, - }).ParseFS(webCombinedDir, "dist/*.html") - + }) return err }) } +func findAndParseTemplates(rootDir fs.FS, funcMap template.FuncMap) (*template.Template, error) { + root := template.New("") + + err := fs.WalkDir(rootDir, ".", func(p string, d fs.DirEntry, e1 error) error { + if d.IsDir() || !strings.HasSuffix(p, ".html") { + return nil + } + + if e1 != nil { + return e1 + } + + fileContents, err := fs.ReadFile(webCombinedDir, p) + if err != nil { + return err + } + + t := root.New(p).Delims("[[", "]]").Funcs(funcMap) + _, err = t.Parse(string(fileContents)) + return err + }) + return root, err +} + func renderOptionTag(value, display string, selectedValue string) template.HTML { var selectedParam string if value == selectedValue { @@ -70,12 +98,16 @@ func renderCheckboxTag(name, id string, checked bool) template.HTML { return template.HTML("") } -func RenderPageTemplate(wr io.Writer, name string, data any) { +func RenderPageTemplate(wr io.Writer, name string, data any) bool { + logger.Logger.Helper() + p := name + ".html" err := pageTemplates.ExecuteTemplate(wr, p, data) if err != nil { logger.Logger.Warn("Failed to render page", "name", name, "err", err) } + + return err == nil } func RenderWebAsset(rw http.ResponseWriter, req *http.Request, name string) { diff --git a/web/web_test.go b/web/web_test.go index 5fdf002..890ddf1 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -1,16 +1,20 @@ package web import ( + "embed" "fmt" "github.com/1f349/lavender/utils" "github.com/stretchr/testify/assert" "html/template" "io/fs" + "path" + "slices" + "strings" "testing" ) func TestLoadPages_FindErrors(t *testing.T) { - glob, err := fs.Glob(webBuild, "dist/*/index.html") + glob, err := fs.Glob(webDist, "dist/*/index.html") assert.NoError(t, err) fmt.Println(glob) @@ -21,8 +25,46 @@ func TestLoadPages_FindErrors(t *testing.T) { "emailHide": utils.EmailHide, "renderOptionTag": renderOptionTag, "renderCheckboxTag": renderCheckboxTag, - }).ParseFS(webBuild, fileName) + }).ParseFS(webDist, fileName) assert.NoError(t, err) }) } } + +//go:embed src/pages +var webSrcPages embed.FS + +func TestLoadPage_FindMissing(t *testing.T) { + paths := make([]string, 0) + err := fs.WalkDir(webSrcPages, "src/pages", func(p string, d fs.DirEntry, err error) error { + if d.IsDir() { + return nil + } + if strings.HasSuffix(path.Base(p), ".astro") { + p = strings.TrimPrefix(p, "src/pages/") + p = strings.TrimSuffix(p, ".astro") + p += ".html" + paths = append(paths, p) + } + return nil + }) + assert.NoError(t, err) + + slices.Sort(paths) + + err = LoadPages("") + assert.NoError(t, err) + + tmpls := make([]string, 0) + + for _, i := range pageTemplates.Templates() { + if i.Name() == "" { + continue + } + tmpls = append(tmpls, i.Name()) + } + + slices.Sort(tmpls) + + assert.ElementsMatch(t, paths, tmpls) +}