mirror of
https://github.com/1f349/lavender.git
synced 2025-02-23 06:05:08 +00:00
Separate auth context, move oauth functions to oauth provider
This commit is contained in:
parent
d1ba2a779d
commit
5546b47da8
10
auth/auth-callback.go
Normal file
10
auth/auth-callback.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "github.com/1f349/lavender/auth/authContext"
|
||||||
|
|
||||||
|
type Callback interface {
|
||||||
|
Provider
|
||||||
|
|
||||||
|
// AttemptCallback processes the login request.
|
||||||
|
AttemptCallback(ctx authContext.TemplateContext) error
|
||||||
|
}
|
31
auth/authContext/button.go
Normal file
31
auth/authContext/button.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package authContext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewButtonContext(req *http.Request, user *database.User) *BaseButtonContext {
|
||||||
|
return &BaseButtonContext{
|
||||||
|
BaseTemplateContext: BaseTemplateContext{
|
||||||
|
req: req,
|
||||||
|
user: user,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ButtonContext interface {
|
||||||
|
Context() context.Context
|
||||||
|
Request() *http.Request
|
||||||
|
Render(data any)
|
||||||
|
__buttonContext()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ButtonContext = &BaseButtonContext{}
|
||||||
|
|
||||||
|
type BaseButtonContext struct {
|
||||||
|
BaseTemplateContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseButtonContext) __buttonContext() {}
|
19
auth/authContext/callback.go
Normal file
19
auth/authContext/callback.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package authContext
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type CallbackContext interface {
|
||||||
|
TemplateContext
|
||||||
|
HandleCallback(rw http.ResponseWriter, req *http.Request)
|
||||||
|
__callbackContext()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ CallbackContext = &BaseCallbackContext{}
|
||||||
|
|
||||||
|
type BaseCallbackContext struct {
|
||||||
|
BaseTemplateContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseCallbackContext) HandleCallback(rw http.ResponseWriter, req *http.Request) {}
|
||||||
|
|
||||||
|
func (b *BaseCallbackContext) __callbackContext() {}
|
40
auth/authContext/form.go
Normal file
40
auth/authContext/form.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package authContext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/1f349/lavender/auth/login-process"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewFormContext(req *http.Request, user *database.User) *BaseFormContext {
|
||||||
|
return &BaseFormContext{
|
||||||
|
BaseTemplateContext: BaseTemplateContext{
|
||||||
|
req: req,
|
||||||
|
user: user,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type FormContext interface {
|
||||||
|
TemplateContext
|
||||||
|
SetUser(user *database.User)
|
||||||
|
UpdateSession(data process.LoginProcessData)
|
||||||
|
__formContext()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ FormContext = &BaseFormContext{}
|
||||||
|
|
||||||
|
type BaseFormContext struct {
|
||||||
|
BaseTemplateContext
|
||||||
|
loginProcessData process.LoginProcessData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseFormContext) SetUser(user *database.User) {
|
||||||
|
b.BaseTemplateContext.user = user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseFormContext) UpdateSession(data process.LoginProcessData) {
|
||||||
|
b.loginProcessData = data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseFormContext) __formContext() {}
|
@ -2,7 +2,6 @@ package authContext
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/1f349/lavender/auth/login-process"
|
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@ -19,21 +18,7 @@ type TemplateContext interface {
|
|||||||
Request() *http.Request
|
Request() *http.Request
|
||||||
User() *database.User
|
User() *database.User
|
||||||
Render(data any)
|
Render(data any)
|
||||||
}
|
__templateContext()
|
||||||
|
|
||||||
type FormContext interface {
|
|
||||||
Context() context.Context
|
|
||||||
Request() *http.Request
|
|
||||||
User() *database.User
|
|
||||||
SetUser(user *database.User)
|
|
||||||
Render(data any)
|
|
||||||
UpdateSession(data login_process.LoginProcessData)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ButtonContext interface {
|
|
||||||
Context() context.Context
|
|
||||||
Request() *http.Request
|
|
||||||
Render(data any)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TemplateContext = &BaseTemplateContext{}
|
var _ TemplateContext = &BaseTemplateContext{}
|
||||||
@ -52,6 +37,6 @@ func (t *BaseTemplateContext) User() *database.User { return t.user }
|
|||||||
|
|
||||||
func (t *BaseTemplateContext) Render(data any) { t.data = data }
|
func (t *BaseTemplateContext) Render(data any) { t.data = data }
|
||||||
|
|
||||||
func (t *BaseTemplateContext) Data() any {
|
func (t *BaseTemplateContext) Data() any { return t.data }
|
||||||
return t.data
|
|
||||||
}
|
func (t *BaseTemplateContext) __templateContext() {}
|
@ -1,4 +1,4 @@
|
|||||||
package login_process
|
package process
|
||||||
|
|
||||||
import "github.com/1f349/mjwt"
|
import "github.com/1f349/mjwt"
|
||||||
|
|
||||||
|
@ -2,19 +2,29 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/cache"
|
"github.com/1f349/cache"
|
||||||
"github.com/1f349/lavender/auth"
|
"github.com/1f349/lavender/auth"
|
||||||
"github.com/1f349/lavender/auth/authContext"
|
"github.com/1f349/lavender/auth/authContext"
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/1f349/lavender/database/types"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/url"
|
"github.com/1f349/lavender/url"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/mrmelon54/pronouns"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/text/language"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OauthCallback interface {
|
||||||
|
OAuthCallback(rw http.ResponseWriter, req *http.Request, info func(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error), cookie func(rw http.ResponseWriter, authData auth.UserAuth, loginName string) bool, redirect func(rw http.ResponseWriter, req *http.Request))
|
||||||
|
}
|
||||||
|
|
||||||
type flowStateData struct {
|
type flowStateData struct {
|
||||||
loginName string
|
loginName string
|
||||||
sso *issuer.WellKnownOIDC
|
sso *issuer.WellKnownOIDC
|
||||||
@ -117,3 +127,155 @@ type oauthServiceLogin int
|
|||||||
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
|
func WithWellKnown(ctx context.Context, login *issuer.WellKnownOIDC) context.Context {
|
||||||
return context.WithValue(ctx, oauthServiceLogin(0), login)
|
return context.WithValue(ctx, oauthServiceLogin(0), login)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o OAuthLogin) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||||
|
sessionData, err := o.fetchUserInfo(sso, token)
|
||||||
|
if err != nil || sessionData.Subject == "" {
|
||||||
|
return auth.UserAuth{}, fmt.Errorf("failed to fetch user info")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(melon): fix this to use a merging of lavender and tulip auth
|
||||||
|
|
||||||
|
// find an existing user with the matching oauth2 namespace and subject
|
||||||
|
var userSubject string
|
||||||
|
err = o.DB.UseTx(req.Context(), func(tx *database.Queries) (err error) {
|
||||||
|
userSubject, err = tx.FindUserByAuth(req.Context(), database.FindUserByAuthParams{
|
||||||
|
AuthType: types.AuthTypeOauth2,
|
||||||
|
AuthNamespace: sso.Namespace,
|
||||||
|
AuthUser: sessionData.Subject,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
})
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
// user already exists
|
||||||
|
err = o.DB.UseTx(req.Context(), func(tx *database.Queries) (err error) {
|
||||||
|
return o.updateOAuth2UserProfile(req.Context(), tx, sessionData)
|
||||||
|
})
|
||||||
|
return auth.UserAuth{
|
||||||
|
Subject: userSubject,
|
||||||
|
Factor: auth.StateExtended,
|
||||||
|
UserInfo: sessionData.UserInfo,
|
||||||
|
}, err
|
||||||
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
|
// happy path for registration
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
// another error occurred
|
||||||
|
return auth.UserAuth{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// guard for disabled registration
|
||||||
|
if !sso.Config.Registration {
|
||||||
|
return auth.UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(melon): rework this
|
||||||
|
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
|
||||||
|
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
||||||
|
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
||||||
|
|
||||||
|
err = o.DB.UseTx(req.Context(), func(tx *database.Queries) (err error) {
|
||||||
|
userSubject, err = tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{
|
||||||
|
Email: uEmail,
|
||||||
|
EmailVerified: uEmailVerified,
|
||||||
|
Name: name,
|
||||||
|
Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
|
||||||
|
AuthNamespace: sso.Namespace,
|
||||||
|
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if adding the user succeeds then update the profile
|
||||||
|
return o.updateOAuth2UserProfile(req.Context(), tx, sessionData)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return auth.UserAuth{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// only continues if the above tx succeeds
|
||||||
|
if err := o.DB.UseTx(req.Context(), func(tx *database.Queries) error {
|
||||||
|
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
|
||||||
|
AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
|
||||||
|
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
|
||||||
|
TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true},
|
||||||
|
Subject: sessionData.Subject,
|
||||||
|
})
|
||||||
|
}); err != nil {
|
||||||
|
return auth.UserAuth{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(melon): this feels bad
|
||||||
|
sessionData = auth.UserAuth{
|
||||||
|
Subject: userSubject,
|
||||||
|
Factor: auth.StateExtended,
|
||||||
|
UserInfo: sessionData.UserInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OAuthLogin) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData auth.UserAuth) error {
|
||||||
|
// all of these updates must succeed
|
||||||
|
return tx.UseTx(ctx, func(tx *database.Queries) error {
|
||||||
|
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
|
||||||
|
|
||||||
|
err := tx.ModifyUserRemoteLogin(ctx, database.ModifyUserRemoteLoginParams{
|
||||||
|
Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
|
||||||
|
ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"),
|
||||||
|
Subject: sessionData.Subject,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns"))
|
||||||
|
if err != nil {
|
||||||
|
pronoun = pronouns.TheyThem
|
||||||
|
}
|
||||||
|
locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale"))
|
||||||
|
if err != nil {
|
||||||
|
locale = language.AmericanEnglish
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.ModifyProfile(ctx, database.ModifyProfileParams{
|
||||||
|
Name: name,
|
||||||
|
Picture: sessionData.UserInfo.GetStringOrEmpty("profile"),
|
||||||
|
Website: sessionData.UserInfo.GetStringOrEmpty("website"),
|
||||||
|
Pronouns: types.UserPronoun{Pronoun: pronoun},
|
||||||
|
Birthdate: sessionData.UserInfo.GetNullDate("birthdate"),
|
||||||
|
Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"),
|
||||||
|
Locale: types.UserLocale{Tag: locale},
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
Subject: sessionData.Subject,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o OAuthLogin) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||||
|
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
|
||||||
|
if err != nil || res.StatusCode != http.StatusOK {
|
||||||
|
return auth.UserAuth{}, fmt.Errorf("request failed")
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
var userInfoJson auth.UserInfoFields
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil {
|
||||||
|
return auth.UserAuth{}, err
|
||||||
|
}
|
||||||
|
subject, ok := userInfoJson.GetString("sub")
|
||||||
|
if !ok {
|
||||||
|
return auth.UserAuth{}, fmt.Errorf("invalid subject")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(melon): there is no need for this
|
||||||
|
//subject += "@" + sso.Config.Namespace
|
||||||
|
|
||||||
|
return auth.UserAuth{
|
||||||
|
Subject: subject,
|
||||||
|
Factor: auth.StateExtended,
|
||||||
|
UserInfo: userInfoJson,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
190
server/login.go
190
server/login.go
@ -4,25 +4,22 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/lavender/auth"
|
"github.com/1f349/lavender/auth"
|
||||||
"github.com/1f349/lavender/auth/authContext"
|
"github.com/1f349/lavender/auth/authContext"
|
||||||
"github.com/1f349/lavender/auth/providers"
|
"github.com/1f349/lavender/auth/providers"
|
||||||
"github.com/1f349/lavender/database"
|
"github.com/1f349/lavender/database"
|
||||||
"github.com/1f349/lavender/database/types"
|
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/logger"
|
"github.com/1f349/lavender/logger"
|
||||||
|
"github.com/1f349/lavender/utils"
|
||||||
"github.com/1f349/lavender/web"
|
"github.com/1f349/lavender/web"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
mjwtAuth "github.com/1f349/mjwt/auth"
|
mjwtAuth "github.com/1f349/mjwt/auth"
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"github.com/mrmelon54/pronouns"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/text/language"
|
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -72,7 +69,7 @@ func (h *httpServer) renderAuthTemplate(req *http.Request, provider auth.Form) (
|
|||||||
|
|
||||||
func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth auth.UserAuth) {
|
func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, userAuth auth.UserAuth) {
|
||||||
if !userAuth.IsGuest() {
|
if !userAuth.IsGuest() {
|
||||||
h.SafeRedirect(rw, req)
|
utils.SafeRedirect(rw, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,7 +145,7 @@ func (h *httpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr
|
|||||||
|
|
||||||
func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth2 auth.UserAuth) {
|
func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth2 auth.UserAuth) {
|
||||||
if !auth2.IsGuest() {
|
if !auth2.IsGuest() {
|
||||||
h.SafeRedirect(rw, req)
|
utils.SafeRedirect(rw, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,9 +219,27 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var authForm auth.Form
|
||||||
|
|
||||||
|
{
|
||||||
|
for _, i := range h.authSources {
|
||||||
|
if form, ok := i.(auth.Form); ok {
|
||||||
|
if req.PostFormValue("provider") == form.Name() {
|
||||||
|
authForm = form
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if authForm == nil {
|
||||||
|
http.Error(rw, "Invalid auth provider", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: rewrite
|
// TODO: rewrite
|
||||||
//err := h.authOAuth.AttemptLogin(ctx, req, nil)
|
formContext := authContext.NewFormContext(req, nil)
|
||||||
var err error
|
err := authForm.AttemptLogin(formContext)
|
||||||
switch {
|
switch {
|
||||||
case errors.As(err, &redirectError):
|
case errors.As(err, &redirectError):
|
||||||
http.Redirect(rw, req, redirectError.Target, redirectError.Code)
|
http.Redirect(rw, req, redirectError.Target, redirectError.Code)
|
||||||
@ -234,133 +249,18 @@ func (h *httpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http
|
|||||||
|
|
||||||
func (h *httpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ auth.UserAuth) {
|
func (h *httpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, _ auth.UserAuth) {
|
||||||
// TODO: rewrite
|
// TODO: rewrite
|
||||||
//h.authOAuth.OAuthCallback(rw, req, h.updateExternalUserInfo, h.setLoginDataCookie, h.SafeRedirect)
|
for _, i := range h.authSources {
|
||||||
}
|
if callback, ok := i.(authContext.CallbackContext); ok {
|
||||||
|
callback.HandleCallback(rw, req)
|
||||||
func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
user := callback.User()
|
||||||
sessionData, err := h.fetchUserInfo(sso, token)
|
h.setLoginDataCookie(rw, auth.UserAuth{
|
||||||
if err != nil || sessionData.Subject == "" {
|
Subject: user.Subject,
|
||||||
return auth.UserAuth{}, fmt.Errorf("failed to fetch user info")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(melon): fix this to use a merging of lavender and tulip auth
|
|
||||||
|
|
||||||
// find an existing user with the matching oauth2 namespace and subject
|
|
||||||
var userSubject string
|
|
||||||
err = h.DbTxError(func(tx *database.Queries) (err error) {
|
|
||||||
userSubject, err = tx.FindUserByAuth(req.Context(), database.FindUserByAuthParams{
|
|
||||||
AuthType: types.AuthTypeOauth2,
|
|
||||||
AuthNamespace: sso.Namespace,
|
|
||||||
AuthUser: sessionData.Subject,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
})
|
|
||||||
switch {
|
|
||||||
case err == nil:
|
|
||||||
// user already exists
|
|
||||||
err = h.DbTxError(func(tx *database.Queries) error {
|
|
||||||
return h.updateOAuth2UserProfile(req.Context(), tx, sessionData)
|
|
||||||
})
|
|
||||||
return auth.UserAuth{
|
|
||||||
Subject: userSubject,
|
|
||||||
Factor: auth.StateExtended,
|
Factor: auth.StateExtended,
|
||||||
UserInfo: sessionData.UserInfo,
|
UserInfo: auth.UserInfoFields{},
|
||||||
}, err
|
}, "loginName")
|
||||||
case errors.Is(err, sql.ErrNoRows):
|
|
||||||
// happy path for registration
|
|
||||||
break
|
break
|
||||||
default:
|
|
||||||
// another error occurred
|
|
||||||
return auth.UserAuth{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// guard for disabled registration
|
|
||||||
if !sso.Config.Registration {
|
|
||||||
return auth.UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(melon): rework this
|
|
||||||
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
|
|
||||||
uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost")
|
|
||||||
uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified")
|
|
||||||
|
|
||||||
err = h.DbTxError(func(tx *database.Queries) (err error) {
|
|
||||||
userSubject, err = tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{
|
|
||||||
Email: uEmail,
|
|
||||||
EmailVerified: uEmailVerified,
|
|
||||||
Name: name,
|
|
||||||
Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
|
|
||||||
AuthNamespace: sso.Namespace,
|
|
||||||
AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if adding the user succeeds then update the profile
|
|
||||||
return h.updateOAuth2UserProfile(req.Context(), tx, sessionData)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return auth.UserAuth{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// only continues if the above tx succeeds
|
|
||||||
if err := h.DbTxError(func(tx *database.Queries) error {
|
|
||||||
return tx.UpdateUserToken(req.Context(), database.UpdateUserTokenParams{
|
|
||||||
AccessToken: sql.NullString{String: token.AccessToken, Valid: true},
|
|
||||||
RefreshToken: sql.NullString{String: token.RefreshToken, Valid: true},
|
|
||||||
TokenExpiry: sql.NullTime{Time: token.Expiry, Valid: true},
|
|
||||||
Subject: sessionData.Subject,
|
|
||||||
})
|
|
||||||
}); err != nil {
|
|
||||||
return auth.UserAuth{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(melon): this feels bad
|
|
||||||
sessionData = auth.UserAuth{
|
|
||||||
Subject: userSubject,
|
|
||||||
Factor: auth.StateExtended,
|
|
||||||
UserInfo: sessionData.UserInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessionData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *httpServer) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData auth.UserAuth) error {
|
|
||||||
// all of these updates must succeed
|
|
||||||
return tx.UseTx(ctx, func(tx *database.Queries) error {
|
|
||||||
name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User")
|
|
||||||
|
|
||||||
err := tx.ModifyUserRemoteLogin(ctx, database.ModifyUserRemoteLoginParams{
|
|
||||||
Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"),
|
|
||||||
ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"),
|
|
||||||
Subject: sessionData.Subject,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns"))
|
|
||||||
if err != nil {
|
|
||||||
pronoun = pronouns.TheyThem
|
|
||||||
}
|
|
||||||
locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale"))
|
|
||||||
if err != nil {
|
|
||||||
locale = language.AmericanEnglish
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.ModifyProfile(ctx, database.ModifyProfileParams{
|
|
||||||
Name: name,
|
|
||||||
Picture: sessionData.UserInfo.GetStringOrEmpty("profile"),
|
|
||||||
Website: sessionData.UserInfo.GetStringOrEmpty("website"),
|
|
||||||
Pronouns: types.UserPronoun{Pronoun: pronoun},
|
|
||||||
Birthdate: sessionData.UserInfo.GetNullDate("birthdate"),
|
|
||||||
Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"),
|
|
||||||
Locale: types.UserLocale{Tag: locale},
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
Subject: sessionData.Subject,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const twelveHours = 12 * time.Hour
|
const twelveHours = 12 * time.Hour
|
||||||
@ -476,6 +376,7 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// TODO: not sure how I want to handle this yet...
|
||||||
*userAuth, err = h.updateExternalUserInfo(req, sso, oauthToken)
|
*userAuth, err = h.updateExternalUserInfo(req, sso, oauthToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -488,28 +389,7 @@ func (h *httpServer) readLoginRefreshCookie(rw http.ResponseWriter, req *http.Re
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
// TODO: not sure how I want to handle this yet...
|
||||||
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
|
func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellKnownOIDC, token *oauth2.Token) (auth.UserAuth, error) {
|
||||||
if err != nil || res.StatusCode != http.StatusOK {
|
return auth.UserAuth{}, fmt.Errorf("no")
|
||||||
return auth.UserAuth{}, fmt.Errorf("request failed")
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
var userInfoJson auth.UserInfoFields
|
|
||||||
if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil {
|
|
||||||
return auth.UserAuth{}, err
|
|
||||||
}
|
|
||||||
subject, ok := userInfoJson.GetString("sub")
|
|
||||||
if !ok {
|
|
||||||
return auth.UserAuth{}, fmt.Errorf("invalid subject")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(melon): there is no need for this
|
|
||||||
//subject += "@" + sso.Config.Namespace
|
|
||||||
|
|
||||||
return auth.UserAuth{
|
|
||||||
Subject: subject,
|
|
||||||
Factor: auth.StateExtended,
|
|
||||||
UserInfo: userInfoJson,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/go-oauth2/oauth2/v4/server"
|
"github.com/go-oauth2/oauth2/v4/server"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -118,24 +117,6 @@ func SetupRouter(r *httprouter.Router, config conf.Conf, mailSender *mail.Mail,
|
|||||||
SetupOAuth2(r, hs, signingKey, db)
|
SetupOAuth2(r, hs, signingKey, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpServer) SafeRedirect(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
redirectUrl := req.FormValue("redirect")
|
|
||||||
if redirectUrl == "" {
|
|
||||||
http.Redirect(rw, req, "/", http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
parse, err := url.Parse(redirectUrl)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, "Failed to parse redirect url: "+redirectUrl, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if parse.Scheme != "" && parse.Opaque != "" && parse.User != nil && parse.Host != "" {
|
|
||||||
http.Error(rw, "Invalid redirect url: "+redirectUrl, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.Redirect(rw, req, parse.String(), http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseClaims(claims string) map[string]bool {
|
func ParseClaims(claims string) map[string]bool {
|
||||||
m := make(map[string]bool)
|
m := make(map[string]bool)
|
||||||
for {
|
for {
|
||||||
|
24
utils/safe-redirect.go
Normal file
24
utils/safe-redirect.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SafeRedirect(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
redirectUrl := req.FormValue("redirect")
|
||||||
|
if redirectUrl == "" {
|
||||||
|
http.Redirect(rw, req, "/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parse, err := url.Parse(redirectUrl)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "Failed to parse redirect url: "+redirectUrl, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if parse.Scheme != "" && parse.Opaque != "" && parse.User != nil && parse.Host != "" {
|
||||||
|
http.Error(rw, "Invalid redirect url: "+redirectUrl, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(rw, req, parse.String(), http.StatusFound)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user