Separate auth context, move oauth functions to oauth provider

This commit is contained in:
Melon 2025-02-11 21:58:02 +00:00
parent d1ba2a779d
commit 5546b47da8
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
10 changed files with 328 additions and 196 deletions

10
auth/auth-callback.go Normal file
View File

@ -0,0 +1,10 @@
package auth
import "github.com/1f349/lavender/auth/authContext"
type Callback interface {
Provider
// AttemptCallback processes the login request.
AttemptCallback(ctx authContext.TemplateContext) error
}

View File

@ -0,0 +1,31 @@
package authContext
import (
"context"
"github.com/1f349/lavender/database"
"net/http"
)
func NewButtonContext(req *http.Request, user *database.User) *BaseButtonContext {
return &BaseButtonContext{
BaseTemplateContext: BaseTemplateContext{
req: req,
user: user,
},
}
}
type ButtonContext interface {
Context() context.Context
Request() *http.Request
Render(data any)
__buttonContext()
}
var _ ButtonContext = &BaseButtonContext{}
type BaseButtonContext struct {
BaseTemplateContext
}
func (b *BaseButtonContext) __buttonContext() {}

View File

@ -0,0 +1,19 @@
package authContext
import "net/http"
type CallbackContext interface {
TemplateContext
HandleCallback(rw http.ResponseWriter, req *http.Request)
__callbackContext()
}
var _ CallbackContext = &BaseCallbackContext{}
type BaseCallbackContext struct {
BaseTemplateContext
}
func (b *BaseCallbackContext) HandleCallback(rw http.ResponseWriter, req *http.Request) {}
func (b *BaseCallbackContext) __callbackContext() {}

40
auth/authContext/form.go Normal file
View File

@ -0,0 +1,40 @@
package authContext
import (
"github.com/1f349/lavender/auth/login-process"
"github.com/1f349/lavender/database"
"net/http"
)
func NewFormContext(req *http.Request, user *database.User) *BaseFormContext {
return &BaseFormContext{
BaseTemplateContext: BaseTemplateContext{
req: req,
user: user,
},
}
}
type FormContext interface {
TemplateContext
SetUser(user *database.User)
UpdateSession(data process.LoginProcessData)
__formContext()
}
var _ FormContext = &BaseFormContext{}
type BaseFormContext struct {
BaseTemplateContext
loginProcessData process.LoginProcessData
}
func (b *BaseFormContext) SetUser(user *database.User) {
b.BaseTemplateContext.user = user
}
func (b *BaseFormContext) UpdateSession(data process.LoginProcessData) {
b.loginProcessData = data
}
func (b *BaseFormContext) __formContext() {}

View File

@ -2,7 +2,6 @@ package authContext
import (
"context"
"github.com/1f349/lavender/auth/login-process"
"github.com/1f349/lavender/database"
"net/http"
)
@ -19,21 +18,7 @@ type TemplateContext interface {
Request() *http.Request
User() *database.User
Render(data any)
}
type FormContext interface {
Context() context.Context
Request() *http.Request
User() *database.User
SetUser(user *database.User)
Render(data any)
UpdateSession(data login_process.LoginProcessData)
}
type ButtonContext interface {
Context() context.Context
Request() *http.Request
Render(data any)
__templateContext()
}
var _ TemplateContext = &BaseTemplateContext{}
@ -52,6 +37,6 @@ func (t *BaseTemplateContext) User() *database.User { return t.user }
func (t *BaseTemplateContext) Render(data any) { t.data = data }
func (t *BaseTemplateContext) Data() any {
return t.data
}
func (t *BaseTemplateContext) Data() any { return t.data }
func (t *BaseTemplateContext) __templateContext() {}

View File

@ -1,4 +1,4 @@
package login_process
package process
import "github.com/1f349/mjwt"

View File

@ -2,19 +2,29 @@ package providers
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/1f349/cache"
"github.com/1f349/lavender/auth"
"github.com/1f349/lavender/auth/authContext"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/url"
"github.com/google/uuid"
"github.com/mrmelon54/pronouns"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"net/http"
"time"
)
type OauthCallback interface {
OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
}
type flowStateData struct {
loginName string
sso *issuer.WellKnownOIDC
@ -117,3 +127,155 @@ type oauthServiceLogin int
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
return context.WithValue(ctx, oauthServiceLogin(0), login)
}
func (o OAuthLogin) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
sessionData, err := o.fetchUserInfo(sso, token)
if err != nil || sessionData.Subject == "" {
return auth.UserAuth{}, fmt.Errorf("failed to fetch user info")
}
// TODO(melon): fix this to use a merging of lavender and tulip auth
// find an existing user with the matching oauth2 namespace and subject
var userSubject string
err = o.DB.UseTx(req.Context(), func(tx *database.Queries) (err error) {
userSubject, err = tx.FindUserByAuth(req.Context(), database.FindUserByAuthParams{
AuthType: types.AuthTypeOauth2,
AuthNamespace: sso.Namespace,
AuthUser: sessionData.Subject,
})
return
})
switch {
case err == nil:
// user already exists
err = o.DB.UseTx(req.Context(), func(tx *database.Queries) (err error) {
return o.updateOAuth2UserProfile(req.Context(), tx, sessionData)
})
return auth.UserAuth{
Subject: userSubject,
Factor: auth.StateExtended,
UserInfo: sessionData.UserInfo,
}, err
case errors.Is(err, sql.ErrNoRows):
// happy path for registration
break
default:
// another error occurred
return auth.UserAuth{}, err
}
// guard for disabled registration
if !sso.Config.Registration {
return auth.UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source")
}
// TODO(melon): rework this
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
err = o.DB.UseTx(req.Context(), func(tx *database.Queries) (err error) {
userSubject, err = tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{
Email: uEmail,
EmailVerified: uEmailVerified,
Name: name,
Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
AuthNamespace: sso.Namespace,
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
})
if err != nil {
return err
}
// if adding the user succeeds then update the profile
return o.updateOAuth2UserProfile(req.Context(), tx, sessionData)
})
if err != nil {
return auth.UserAuth{}, err
}
// only continues if the above tx succeeds
if err := o.DB.UseTx(req.Context(), func(tx *database.Queries) error {
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true},
Subject: sessionData.Subject,
})
}); err != nil {
return auth.UserAuth{}, err
}
// TODO(melon): this feels bad
sessionData = auth.UserAuth{
Subject: userSubject,
Factor: auth.StateExtended,
UserInfo: sessionData.UserInfo,
}
return sessionData, nil
}
func (o OAuthLogin) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData auth.UserAuth) error {
// all of these updates must succeed
return tx.UseTx(ctx, func(tx *database.Queries) error {
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
err := tx.ModifyUserRemoteLogin(ctx, database.ModifyUserRemoteLoginParams{
Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"),
Subject: sessionData.Subject,
})
if err != nil {
return err
}
pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns"))
if err != nil {
pronoun = pronouns.TheyThem
}
locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale"))
if err != nil {
locale = language.AmericanEnglish
}
return tx.ModifyProfile(ctx, database.ModifyProfileParams{
Name: name,
Picture: sessionData.UserInfo.GetStringOrEmpty("profile"),
Website: sessionData.UserInfo.GetStringOrEmpty("website"),
Pronouns: types.UserPronoun{Pronoun: pronoun},
Birthdate: sessionData.UserInfo.GetNullDate("birthdate"),
Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"),
Locale: types.UserLocale{Tag: locale},
UpdatedAt: time.Now(),
Subject: sessionData.Subject,
})
})
}
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
if err != nil || res.StatusCode != http.StatusOK {
return auth.UserAuth{}, fmt.Errorf("request failed")
}
defer res.Body.Close()
var userInfoJson auth.UserInfoFields
if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil {
return auth.UserAuth{}, err
}
subject, ok := userInfoJson.GetString("sub")
if !ok {
return auth.UserAuth{}, fmt.Errorf("invalid subject")
}
// TODO(melon): there is no need for this
//subject += "@" + sso.Config.Namespace
return auth.UserAuth{
Subject: subject,
Factor: auth.StateExtended,
UserInfo: userInfoJson,
}, nil
}

View File

@ -4,25 +4,22 @@ import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/1f349/lavender/auth"
"github.com/1f349/lavender/auth/authContext"
"github.com/1f349/lavender/auth/providers"
"github.com/1f349/lavender/database"
"github.com/1f349/lavender/database/types"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/logger"
"github.com/1f349/lavender/utils"
"github.com/1f349/lavender/web"
"github.com/1f349/mjwt"
mjwtAuth "github.com/1f349/mjwt/auth"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/julienschmidt/httprouter"
"github.com/mrmelon54/pronouns"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"html/template"
"net/http"
"net/url"
@ -72,7 +69,7 @@ func (h *httpServer) renderAuthTemplate(req *http.Request, provider auth.Form) (
func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth auth.UserAuth) {
if !userAuth.IsGuest() {
h.SafeRedirect(rw, req)
utils.SafeRedirect(rw, req)
return
}
@ -148,7 +145,7 @@ func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr
func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth2 auth.UserAuth) {
if !auth2.IsGuest() {
h.SafeRedirect(rw, req)
utils.SafeRedirect(rw, req)
return
}
@ -222,9 +219,27 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http
return
}
var authForm auth.Form
{
for _, i := range h.authSources {
if form, ok := i.(auth.Form); ok {
if req.PostFormValue("provider") == form.Name() {
authForm = form
break
}
}
}
}
if authForm == nil {
http.Error(rw, "Invalid auth provider", http.StatusBadRequest)
return
}
// TODO: rewrite
//err := h.authOAuth.AttemptLogin(ctx, req, nil)
var err error
formContext := authContext.NewFormContext(req, nil)
err := authForm.AttemptLogin(formContext)
switch {
case errors.As(err, &redirectError):
http.Redirect(rw, req, redirectError.Target, redirectError.Code)
@ -234,133 +249,18 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http
func (h *httpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ auth.UserAuth) {
// TODO: rewrite
//h.authOAuth.OAuthCallback(rw, req, h.updateExternalUserInfo, h.setLoginDataCookie, h.SafeRedirect)
}
func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
sessionData, err := h.fetchUserInfo(sso, token)
if err != nil || sessionData.Subject == "" {
return auth.UserAuth{}, fmt.Errorf("failed to fetch user info")
}
// TODO(melon): fix this to use a merging of lavender and tulip auth
// find an existing user with the matching oauth2 namespace and subject
var userSubject string
err = h.DbTxError(func(tx *database.Queries) (err error) {
userSubject, err = tx.FindUserByAuth(req.Context(), database.FindUserByAuthParams{
AuthType: types.AuthTypeOauth2,
AuthNamespace: sso.Namespace,
AuthUser: sessionData.Subject,
})
return
})
switch {
case err == nil:
// user already exists
err = h.DbTxError(func(tx *database.Queries) error {
return h.updateOAuth2UserProfile(req.Context(), tx, sessionData)
})
return auth.UserAuth{
Subject: userSubject,
Factor: auth.StateExtended,
UserInfo: sessionData.UserInfo,
}, err
case errors.Is(err, sql.ErrNoRows):
// happy path for registration
break
default:
// another error occurred
return auth.UserAuth{}, err
}
// guard for disabled registration
if !sso.Config.Registration {
return auth.UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source")
}
// TODO(melon): rework this
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
err = h.DbTxError(func(tx *database.Queries) (err error) {
userSubject, err = tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{
Email: uEmail,
EmailVerified: uEmailVerified,
Name: name,
Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
AuthNamespace: sso.Namespace,
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
})
if err != nil {
return err
for _, i := range h.authSources {
if callback, ok := i.(authContext.CallbackContext); ok {
callback.HandleCallback(rw, req)
user := callback.User()
h.setLoginDataCookie(rw, auth.UserAuth{
Subject: user.Subject,
Factor: auth.StateExtended,
UserInfo: auth.UserInfoFields{},
}, "loginName")
break
}
// if adding the user succeeds then update the profile
return h.updateOAuth2UserProfile(req.Context(), tx, sessionData)
})
if err != nil {
return auth.UserAuth{}, err
}
// only continues if the above tx succeeds
if err := h.DbTxError(func(tx *database.Queries) error {
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true},
Subject: sessionData.Subject,
})
}); err != nil {
return auth.UserAuth{}, err
}
// TODO(melon): this feels bad
sessionData = auth.UserAuth{
Subject: userSubject,
Factor: auth.StateExtended,
UserInfo: sessionData.UserInfo,
}
return sessionData, nil
}
func (h *httpServer) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData auth.UserAuth) error {
// all of these updates must succeed
return tx.UseTx(ctx, func(tx *database.Queries) error {
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
err := tx.ModifyUserRemoteLogin(ctx, database.ModifyUserRemoteLoginParams{
Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"),
Subject: sessionData.Subject,
})
if err != nil {
return err
}
pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns"))
if err != nil {
pronoun = pronouns.TheyThem
}
locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale"))
if err != nil {
locale = language.AmericanEnglish
}
return tx.ModifyProfile(ctx, database.ModifyProfileParams{
Name: name,
Picture: sessionData.UserInfo.GetStringOrEmpty("profile"),
Website: sessionData.UserInfo.GetStringOrEmpty("website"),
Pronouns: types.UserPronoun{Pronoun: pronoun},
Birthdate: sessionData.UserInfo.GetNullDate("birthdate"),
Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"),
Locale: types.UserLocale{Tag: locale},
UpdatedAt: time.Now(),
Subject: sessionData.Subject,
})
})
}
const twelveHours = 12 * time.Hour
@ -476,6 +376,7 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
return nil
})
// TODO: not sure how I want to handle this yet...
*userAuth, err = h.updateExternalUserInfo(req, sso, oauthToken)
if err != nil {
return err
@ -488,28 +389,7 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
return nil
}
func (h *httpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
if err != nil || res.StatusCode != http.StatusOK {
return auth.UserAuth{}, fmt.Errorf("request failed")
}
defer res.Body.Close()
var userInfoJson auth.UserInfoFields
if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil {
return auth.UserAuth{}, err
}
subject, ok := userInfoJson.GetString("sub")
if !ok {
return auth.UserAuth{}, fmt.Errorf("invalid subject")
}
// TODO(melon): there is no need for this
//subject += "@" + sso.Config.Namespace
return auth.UserAuth{
Subject: subject,
Factor: auth.StateExtended,
UserInfo: userInfoJson,
}, nil
// TODO: not sure how I want to handle this yet...
func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
return auth.UserAuth{}, fmt.Errorf("no")
}

View File

@ -16,7 +16,6 @@ import (
"github.com/go-oauth2/oauth2/v4/server"
"github.com/julienschmidt/httprouter"
"net/http"
"net/url"
"path"
"strings"
)
@ -118,24 +117,6 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
SetupOAuth2(r, hs, signingKey, db)
}
func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) {
redirectUrl := req.FormValue("redirect")
if redirectUrl == "" {
http.Redirect(rw, req, "/", http.StatusFound)
return
}
parse, err := url.Parse(redirectUrl)
if err != nil {
http.Error(rw, "Failed to parse redirect url: "+redirectUrl, http.StatusBadRequest)
return
}
if parse.Scheme != "" && parse.Opaque != "" && parse.User != nil && parse.Host != "" {
http.Error(rw, "Invalid redirect url: "+redirectUrl, http.StatusBadRequest)
return
}
http.Redirect(rw, req, parse.String(), http.StatusFound)
}
func ParseClaims(claims string) map[string]bool {
m := make(map[string]bool)
for {

24
utils/safe-redirect.go Normal file
View File

@ -0,0 +1,24 @@
package utils
import (
"net/http"
"net/url"
)
func SafeRedirect(rw http.ResponseWriter, req *http.Request) {
redirectUrl := req.FormValue("redirect")
if redirectUrl == "" {
http.Redirect(rw, req, "/", http.StatusFound)
return
}
parse, err := url.Parse(redirectUrl)
if err != nil {
http.Error(rw, "Failed to parse redirect url: "+redirectUrl, http.StatusBadRequest)
return
}
if parse.Scheme != "" && parse.Opaque != "" && parse.User != nil && parse.Host != "" {
http.Error(rw, "Invalid redirect url: "+redirectUrl, http.StatusBadRequest)
return
}
http.Redirect(rw, req, parse.String(), http.StatusFound)
}