diff --git a/auth/authContext/form.go b/auth/authContext/form.go index e4c47ea..9a73fa8 100644 --- a/auth/authContext/form.go +++ b/auth/authContext/form.go @@ -19,7 +19,7 @@ func NewFormContext(req *http.Request, user *database.User, rw http.ResponseWrit type FormContext interface { TemplateContext SetUser(user *database.User) - UpdateSession(data process.LoginProcessData) + UpdateSession(update process.UpdateLoginProcessData) GetLoginProcessData() process.LoginProcessData ResponseWriter() http.ResponseWriter __formContext() @@ -39,7 +39,9 @@ func (b *BaseFormContext) GetLoginProcessData() process.LoginProcessData { retur func (b *BaseFormContext) SetUser(user *database.User) { b.BaseTemplateContext.user = user } -func (b *BaseFormContext) UpdateSession(data process.LoginProcessData) { b.loginProcessData = data } +func (b *BaseFormContext) UpdateSession(data process.UpdateLoginProcessData) { + b.loginProcessData = b.loginProcessData.Merge(data) +} func (b *BaseFormContext) ResponseWriter() http.ResponseWriter { return b.rw } diff --git a/auth/process/login-process-data.go b/auth/process/login-process-data.go index da13bd5..1507ab5 100644 --- a/auth/process/login-process-data.go +++ b/auth/process/login-process-data.go @@ -1,7 +1,9 @@ package process import ( + "errors" "github.com/1f349/mjwt" + "github.com/gobuffalo/nulls" ) var _ mjwt.Claims = (*LoginProcessData)(nil) @@ -13,10 +15,34 @@ var _ mjwt.Claims = (*LoginProcessData)(nil) // // TODO: add some actual session management type LoginProcessData struct { - State State - Email string + State State + Email string + Subject string } -func (d LoginProcessData) Valid() error { return nil } - func (d LoginProcessData) Type() string { return "login-process" } + +func (d LoginProcessData) Valid() error { + if !d.State.IsValid() { + return errors.New("invalid state") + } + return nil +} + +func (d LoginProcessData) Merge(update UpdateLoginProcessData) LoginProcessData { + d.State = update.State + if update.Email.Valid { + d.Email = update.Email.String + } + if update.Subject.Valid { + d.Subject = update.Subject.String + } + return d +} + +// UpdateLoginProcessData will modify the values in LoginProcessData using Merge +type UpdateLoginProcessData struct { + State State + Email nulls.String + Subject nulls.String +} diff --git a/auth/providers/base.go b/auth/providers/base.go index 9e7d2a3..e8fdba6 100644 --- a/auth/providers/base.go +++ b/auth/providers/base.go @@ -10,6 +10,7 @@ import ( "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/logger" "github.com/1f349/lavender/utils" + "github.com/gobuffalo/nulls" "github.com/google/uuid" "golang.org/x/oauth2" "net/http" @@ -140,9 +141,9 @@ func (b *Base) AttemptLogin(ctx authContext.FormContext) error { } } - ctx.UpdateSession(process.LoginProcessData{ - Email: loginName, + ctx.UpdateSession(process.UpdateLoginProcessData{ State: process.StateBase, + Email: nulls.NewString(loginName), }) return nil diff --git a/auth/providers/otp.go b/auth/providers/otp.go index 6245cd4..0526025 100644 --- a/auth/providers/otp.go +++ b/auth/providers/otp.go @@ -66,6 +66,10 @@ func (o *OtpLogin) AttemptLogin(ctx authContext.FormContext) error { if !validateTotp(user.OtpSecret, int(user.OtpDigits), code) { return auth.BasicUserSafeError(http.StatusBadRequest, "invalid OTP code") } + + ctx.UpdateSession(process.UpdateLoginProcessData{ + State: process.StateAuthenticated, + }) return nil } diff --git a/auth/providers/password.go b/auth/providers/password.go index 130264a..a55b949 100644 --- a/auth/providers/password.go +++ b/auth/providers/password.go @@ -8,6 +8,7 @@ import ( "github.com/1f349/lavender/auth/authContext" "github.com/1f349/lavender/auth/process" "github.com/1f349/lavender/database" + "github.com/gobuffalo/nulls" "net/http" ) @@ -64,9 +65,10 @@ func (p *PasswordLogin) AttemptLogin(ctx authContext.FormContext) error { return err } ctx.SetUser(&user) - ctx.UpdateSession(process.LoginProcessData{ + ctx.UpdateSession(process.UpdateLoginProcessData{ State: process.StateBasic, - Email: un, + Email: nulls.NewString(un), + Subject: nulls.NewString(user.Subject), }) return nil case errors.Is(err, sql.ErrNoRows): diff --git a/go.mod b/go.mod index 2a29420..6b3919b 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/cloudflare/tableflip v1.2.3 github.com/emersion/go-message v0.18.2 github.com/go-oauth2/oauth2/v4 v4.5.2 + github.com/gobuffalo/nulls v0.4.2 github.com/golang-jwt/jwt/v4 v4.5.1 github.com/golang-migrate/migrate/v4 v4.18.1 github.com/google/subcommands v1.2.0 @@ -42,6 +43,7 @@ require ( github.com/emersion/go-smtp v0.21.3 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect + github.com/gofrs/uuid v4.2.0+incompatible // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index d43f63e..0b6e246 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,11 @@ github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KE github.com/go-oauth2/oauth2/v4 v4.5.2 h1:CuZhD3lhGuI6aNLyUbRHXsgG2RwGRBOuCBfd4WQKqBQ= github.com/go-oauth2/oauth2/v4 v4.5.2/go.mod h1:wk/2uLImWIa9VVQDgxz99H2GDbhmfi/9/Xr+GvkSUSQ= github.com/go-session/session v3.1.2+incompatible/go.mod h1:8B3iivBQjrz/JtC68Np2T1yBBLxTan3mn/3OM0CyRt0= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/gobuffalo/nulls v0.4.2 h1:GAqBR29R3oPY+WCC7JL9KKk9erchaNuV6unsOSZGQkw= +github.com/gobuffalo/nulls v0.4.2/go.mod h1:EElw2zmBYafU2R9W4Ii1ByIj177wA/pc0JdjtD0EsH8= +github.com/gofrs/uuid v4.2.0+incompatible h1:yyYWMnhkhrKwwr8gAOcOCYxOOscHgDS9yZgBrnJfGa0= +github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= @@ -92,6 +97,8 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= @@ -107,6 +114,7 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= @@ -117,6 +125,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= @@ -154,9 +164,12 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI= diff --git a/server/login.go b/server/login.go index c0db5e5..5bb8087 100644 --- a/server/login.go +++ b/server/login.go @@ -50,8 +50,8 @@ func (h *httpServer) getAuthWithState(state process.State) auth.Provider { return nil } -func (h *httpServer) renderAuthTemplate(req *http.Request, provider auth.Form, processData process.LoginProcessData) (template.HTML, error) { - tmpCtx := authContext.NewTemplateContext(req, new(database.User), processData) +func (h *httpServer) renderAuthTemplate(req *http.Request, provider auth.Form, processData process.LoginProcessData, user *database.User) (template.HTML, error) { + tmpCtx := authContext.NewTemplateContext(req, user, processData) err := provider.RenderTemplate(tmpCtx) if err != nil { @@ -72,13 +72,16 @@ func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr } var processData process.LoginProcessData + var user *database.User jwtCookie, err := readJwtCookie[process.LoginProcessData](req, "lavender-login-process", h.signingKey.KeyStore()) if err == nil { processData = jwtCookie.Claims + user = h.resolveUser(req.Context(), processData) } // TODO: some of this should be more like tulip + fmt.Println("Starting login process with data", "process", processData) buttonCtx := authContext.NewTemplateContext(req, new(database.User), processData) @@ -100,15 +103,29 @@ func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr // Maybe the admin has disabled some login providers but does have a button based provider available? form, ok := provider.(auth.Form) - if provider != nil && ok { - renderTemplate, err = h.renderAuthTemplate(req, form, processData) - if err != nil { - logger.Logger.Warn("No provider for login") - web.RenderPageTemplate(rw, "login-error", struct { - Error string `json:"error"` - }{Error: "No available provider for login"}) - return - } + if provider == nil || !ok { + logger.Logger.Warn("Provider does not support forms", "state", processData.State, "provider", provider) + web.RenderPageTemplate(rw, "login-error", struct { + ServiceName string `json:"service_name"` + Error string `json:"error"` + }{ + ServiceName: h.conf.ServiceName, + Error: "No available provider for login", + }) + return + } + + renderTemplate, err = h.renderAuthTemplate(req, form, processData, user) + if err != nil { + logger.Logger.Warn("renderAuthTemplate()", "state", processData.State, "provider", provider.Name(), "err", err) + web.RenderPageTemplate(rw, "login-error", struct { + ServiceName string `json:"service_name"` + Error string `json:"error"` + }{ + ServiceName: h.conf.ServiceName, + Error: "No available provider for login", + }) + return } // render different page sources @@ -129,6 +146,7 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http } var processData process.LoginProcessData + jwtCookie, err := readJwtCookie[process.LoginProcessData](req, "lavender-login-process", h.signingKey.KeyStore()) if err == nil { processData = jwtCookie.Claims @@ -350,6 +368,17 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re return nil } +func (h *httpServer) resolveUser(ctx context.Context, data process.LoginProcessData) *database.User { + // resolve database.User struct + if data.Subject != "" { + userRaw, err := h.db.GetUser(ctx, data.Subject) + if err == nil { + return &userRaw + } + } + return nil +} + // TODO: not sure how I want to handle this yet... func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) { return auth.UserAuth{}, fmt.Errorf("no") diff --git a/web/src/pages/login-error.astro b/web/src/pages/login-error.astro new file mode 100644 index 0000000..77bc9d3 --- /dev/null +++ b/web/src/pages/login-error.astro @@ -0,0 +1,7 @@ +--- +import Layout from '../layouts/Layout.astro'; +--- + + +

An error occurred while logging in: [[ .Error ]]

+