diff --git a/conf/conf.go b/conf/conf.go index 49c09eb..bd8f734 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -6,13 +6,13 @@ import ( ) type Conf struct { - Listen string `yaml:"listen"` - BaseUrl string `yaml:"baseUrl"` - ServiceName string `yaml:"serviceName"` - Issuer string `yaml:"issuer"` - Kid string `yaml:"kid"` - Namespace string `yaml:"namespace"` - OtpIssuer string `yaml:"otpIssuer"` - Mail mail.Mail `yaml:"mail"` - SsoServices map[string]issuer.SsoConfig `yaml:"ssoServices"` + Listen string `yaml:"listen"` + BaseUrl string `yaml:"baseUrl"` + ServiceName string `yaml:"serviceName"` + Issuer string `yaml:"issuer"` + Kid string `yaml:"kid"` + Namespace string `yaml:"namespace"` + OtpIssuer string `yaml:"otpIssuer"` + Mail mail.Mail `yaml:"mail"` + SsoServices []issuer.SsoConfig `yaml:"ssoServices"` } diff --git a/database/queries/users.sql b/database/queries/users.sql index 0d33bfc..2c57ab9 100644 --- a/database/queries/users.sql +++ b/database/queries/users.sql @@ -50,3 +50,10 @@ UPDATE users SET active= false, to_delete = true WHERE subject = ?; + +-- name: FindUserByAuth :one +SELECT subject +FROM users +WHERE auth_type = ? + AND auth_namespace = ? + AND auth_user = ?; diff --git a/database/users.sql.go b/database/users.sql.go index 7b7b6ef..ce0da96 100644 --- a/database/users.sql.go +++ b/database/users.sql.go @@ -13,6 +13,27 @@ import ( "github.com/1f349/lavender/password" ) +const findUserByAuth = `-- name: FindUserByAuth :one +SELECT subject +FROM users +WHERE auth_type = ? + AND auth_namespace = ? + AND auth_user = ? +` + +type FindUserByAuthParams struct { + AuthType types.AuthType `json:"auth_type"` + AuthNamespace string `json:"auth_namespace"` + AuthUser string `json:"auth_user"` +} + +func (q *Queries) FindUserByAuth(ctx context.Context, arg FindUserByAuthParams) (string, error) { + row := q.db.QueryRowContext(ctx, findUserByAuth, arg.AuthType, arg.AuthNamespace, arg.AuthUser) + var subject string + err := row.Scan(&subject) + return subject, err +} + const flagUserAsDeleted = `-- name: FlagUserAsDeleted :exec UPDATE users SET active= false, diff --git a/issuer/sso.go b/issuer/sso.go index 59ddf40..3683a17 100644 --- a/issuer/sso.go +++ b/issuer/sso.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "slices" - "strings" ) var httpGet = http.Get @@ -17,9 +16,11 @@ var httpGet = http.Get // SsoConfig is the base URL for an OAUTH/OPENID/SSO login service // The path `/.well-known/openid-configuration` should be available type SsoConfig struct { - Addr utils.JsonUrl `json:"addr"` // https://login.example.com - Namespace string `json:"namespace"` // example.com - Client SsoConfigClient `json:"client"` + Addr utils.JsonUrl `json:"addr" yaml:"addr"` // https://login.example.com + Namespace string `json:"namespace" yaml:"namespace"` // example.com + Registration bool `json:"registration" yaml:"registration"` + LoginWithButton bool `json:"login_with_button" yaml:"loginWithButton"` + Client SsoConfigClient `json:"client" yaml:"client"` } type SsoConfigClient struct { @@ -30,14 +31,10 @@ type SsoConfigClient struct { func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) { // generate openid config url - u := s.Addr.String() - if !strings.HasSuffix(u, "/") { - u += "/" - } - u += ".well-known/openid-configuration" + u := s.Addr.JoinPath(".well-known/openid-configuration") // fetch metadata - get, err := httpGet(u) + get, err := httpGet(u.String()) if err != nil { return nil, err } diff --git a/server/db.go b/server/db.go index 4627152..e32bf5c 100644 --- a/server/db.go +++ b/server/db.go @@ -1,13 +1,24 @@ package server import ( - "errors" "github.com/1f349/lavender/database" "github.com/1f349/lavender/logger" "net/http" ) -var ErrDatabaseActionFailed = errors.New("database action failed") +var _ error = (*ErrDatabaseActionFailed)(nil) + +type ErrDatabaseActionFailed struct { + err error +} + +func (e ErrDatabaseActionFailed) Error() string { + return "database action failed: " + e.err.Error() +} + +func (e ErrDatabaseActionFailed) Unwrap() error { + return e.err +} // DbTx wraps a database transaction with http error messages and a simple action // function. If the action function returns an error the transaction will be @@ -27,7 +38,7 @@ func (h *httpServer) DbTxError(action func(tx *database.Queries) error) error { err := action(h.db) if err != nil { logger.Logger.Warn("Database action error", "err", err) - return ErrDatabaseActionFailed + return ErrDatabaseActionFailed{err: err} } return nil } diff --git a/server/login.go b/server/login.go index 96c56b6..5a55abd 100644 --- a/server/login.go +++ b/server/login.go @@ -149,72 +149,62 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK return UserAuth{}, fmt.Errorf("failed to fetch user info") } - err = h.DbTxError(func(tx *database.Queries) error { - name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User") + // TODO(melon): fix this to use a merging of lavender and tulip auth - _, err = tx.GetUser(req.Context(), sessionData.Subject) - uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") - uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") - if errors.Is(err, sql.ErrNoRows) { - _, err := tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{ - Email: uEmail, - EmailVerified: uEmailVerified, - Name: name, - Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"), - AuthNamespace: sso.Namespace, - AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"), - }) - return err - } - - err = tx.ModifyUserEmail(req.Context(), database.ModifyUserEmailParams{ - Email: uEmail, - EmailVerified: uEmailVerified, - Subject: sessionData.Subject, - }) - if err != nil { - return err - } - - err = tx.ModifyUserAuth(req.Context(), database.ModifyUserAuthParams{ + // find an existing user with the matching oauth2 namespace and subject + var userSubject string + err = h.DbTxError(func(tx *database.Queries) (err error) { + userSubject, err = tx.FindUserByAuth(req.Context(), database.FindUserByAuthParams{ AuthType: types.AuthTypeOauth2, AuthNamespace: sso.Namespace, + AuthUser: sessionData.Subject, + }) + return + }) + switch { + case err == nil: + // user already exists + err = h.DbTxError(func(tx *database.Queries) error { + return h.updateOAuth2UserProfile(req.Context(), tx, sessionData) + }) + return UserAuth{ + Subject: userSubject, + NeedOtp: sessionData.NeedOtp, + UserInfo: sessionData.UserInfo, + }, err + case errors.Is(err, sql.ErrNoRows): + // happy path for registration + break + default: + // another error occurred + return UserAuth{}, err + } + + // guard for disabled registration + if !sso.Config.Registration { + return UserAuth{}, fmt.Errorf("registration is not enabled for this authentication source") + } + + // TODO(melon): rework this + name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User") + uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") + uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") + + err = h.DbTxError(func(tx *database.Queries) (err error) { + userSubject, err = tx.AddOAuthUser(req.Context(), database.AddOAuthUserParams{ + Email: uEmail, + EmailVerified: uEmailVerified, + Name: name, + Username: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"), + AuthNamespace: sso.Namespace, AuthUser: sessionData.UserInfo.GetStringOrEmpty("sub"), - Subject: sessionData.Subject, }) if err != nil { return err } - err = tx.ModifyUserRemoteLogin(req.Context(), database.ModifyUserRemoteLoginParams{ - Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"), - ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"), - Subject: sessionData.Subject, - }) - if err != nil { - return err - } - - pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns")) - if err != nil { - pronoun = pronouns.TheyThem - } - locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale")) - if err != nil { - locale = language.AmericanEnglish - } - - return tx.ModifyProfile(req.Context(), database.ModifyProfileParams{ - Name: name, - Picture: sessionData.UserInfo.GetStringOrEmpty("profile"), - Website: sessionData.UserInfo.GetStringOrEmpty("website"), - Pronouns: types.UserPronoun{Pronoun: pronoun}, - Birthdate: sessionData.UserInfo.GetNullDate("birthdate"), - Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"), - Locale: types.UserLocale{Tag: locale}, - UpdatedAt: time.Now(), - Subject: sessionData.Subject, - }) + // if adding the user succeeds then update the profile + return h.updateOAuth2UserProfile(req.Context(), tx, sessionData) }) if err != nil { return UserAuth{}, err @@ -232,9 +222,53 @@ func (h *httpServer) updateExternalUserInfo(req *http.Request, sso *issuer.WellK return UserAuth{}, err } + // TODO(melon): this feels bad + sessionData = UserAuth{ + Subject: userSubject, + NeedOtp: sessionData.NeedOtp, + UserInfo: sessionData.UserInfo, + } + return sessionData, nil } +func (h *httpServer) updateOAuth2UserProfile(ctx context.Context, tx *database.Queries, sessionData UserAuth) error { + // all of these updates must succeed + return tx.UseTx(ctx, func(tx *database.Queries) error { + name := sessionData.UserInfo.GetStringOrDefault("name", "Unknown User") + + err := tx.ModifyUserRemoteLogin(ctx, database.ModifyUserRemoteLoginParams{ + Login: sessionData.UserInfo.GetStringFromKeysOrEmpty("login", "preferred_username"), + ProfileUrl: sessionData.UserInfo.GetStringOrEmpty("profile"), + Subject: sessionData.Subject, + }) + if err != nil { + return err + } + + pronoun, err := pronouns.FindPronoun(sessionData.UserInfo.GetStringOrEmpty("pronouns")) + if err != nil { + pronoun = pronouns.TheyThem + } + locale, err := language.Parse(sessionData.UserInfo.GetStringOrEmpty("locale")) + if err != nil { + locale = language.AmericanEnglish + } + + return tx.ModifyProfile(ctx, database.ModifyProfileParams{ + Name: name, + Picture: sessionData.UserInfo.GetStringOrEmpty("profile"), + Website: sessionData.UserInfo.GetStringOrEmpty("website"), + Pronouns: types.UserPronoun{Pronoun: pronoun}, + Birthdate: sessionData.UserInfo.GetNullDate("birthdate"), + Zone: sessionData.UserInfo.GetStringOrDefault("zoneinfo", "UTC"), + Locale: types.UserLocale{Tag: locale}, + UpdatedAt: time.Now(), + Subject: sessionData.Subject, + }) + }) +} + const twelveHours = 12 * time.Hour const oneWeek = 7 * 24 * time.Hour @@ -257,7 +291,7 @@ func (l lavenderLoginRefresh) Valid() error { return l.RefreshTokenClaims.Valid( func (l lavenderLoginRefresh) Type() string { return "lavender-login-refresh" } func (h *httpServer) setLoginDataCookie2(rw http.ResponseWriter, authData UserAuth) bool { - // TODO(melon): should probably merge there methods + // TODO(melon): should probably merge these methods return h.setLoginDataCookie(rw, authData, "") } @@ -377,7 +411,9 @@ func (h *httpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Toke if !ok { return UserAuth{}, fmt.Errorf("invalid subject") } - subject += "@" + sso.Config.Namespace + + // TODO(melon): there is no need for this + //subject += "@" + sso.Config.Namespace return UserAuth{ Subject: subject, diff --git a/server/oauth.go b/server/oauth.go index ef11372..e79fb4a 100644 --- a/server/oauth.go +++ b/server/oauth.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "fmt" clientStore "github.com/1f349/lavender/client-store" "github.com/1f349/lavender/database" "github.com/1f349/lavender/logger" @@ -9,6 +10,8 @@ import ( "github.com/1f349/lavender/scope" "github.com/1f349/lavender/utils" "github.com/1f349/mjwt" + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" "github.com/go-oauth2/oauth2/v4/generates" "github.com/go-oauth2/oauth2/v4/manage" "github.com/go-oauth2/oauth2/v4/server" @@ -16,12 +19,13 @@ import ( "github.com/julienschmidt/httprouter" "net/http" "net/url" + "runtime" "strings" "time" ) func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *database.Queries) { - oauthManager := manage.NewManager() + oauthManager := manage.NewDefaultManager() oauthManager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) oauthManager.MustTokenStorage(store.NewMemoryTokenStore()) @@ -53,7 +57,19 @@ func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *dat } return a, nil }) + oauthSrv.ClientAuthorizedHandler = func(clientID string, grant oauth2.GrantType) (allowed bool, err error) { + return true, nil + } addIdTokenSupport(oauthSrv, db, key) + oauthSrv.ResponseErrorHandler = func(re *errors.Response) { + buf := make([]byte, 1<<20) + n := runtime.Stack(buf, false) + fmt.Printf("%#v\n", re) + fmt.Printf("%s\n", buf[:n]) + } + + hs.oauthMgr = oauthManager + hs.oauthSrv = oauthSrv r.GET("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) r.POST("/authorize", hs.RequireAuthentication(hs.authorizeEndpoint)) @@ -62,9 +78,11 @@ func SetupOAuth2(r *httprouter.Router, hs *httpServer, key *mjwt.Issuer, db *dat http.Error(rw, "Failed to handle token request", http.StatusInternalServerError) } }) + r.GET("/userinfo", hs.userInfoRequest) + r.OPTIONS("/userinfo", hs.userInfoRequest) } -func (h *httpServer) userInfoRequest(rw http.ResponseWriter, req *http.Request) { +func (h *httpServer) userInfoRequest(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { rw.Header().Set("Access-Control-Allow-Credentials", "true") rw.Header().Set("Access-Control-Allow-Headers", "Authorization,Content-Type") rw.Header().Set("Access-Control-Allow-Origin", strings.TrimSuffix(req.Referer(), "/")) @@ -80,12 +98,6 @@ func (h *httpServer) userInfoRequest(rw http.ResponseWriter, req *http.Request) } userId := token.GetUserID() - sso := h.manager.FindServiceFromLogin(userId) - if sso == nil { - http.Error(rw, "Invalid user", http.StatusBadRequest) - return - } - var user database.User if h.DbTx(rw, func(tx *database.Queries) (err error) { user, err = tx.GetUser(req.Context(), userId) diff --git a/test-client/pop2.js b/test-client/pop2.js index 0220d39..62199a2 100644 --- a/test-client/pop2.js +++ b/test-client/pop2.js @@ -33,7 +33,7 @@ parseInt(window.location.hash.replace(/^.*expires_in=([^&]+).*$/, '$1')) ); } - if (window.location.search.indexOf('error=')) { + if (window.location.hash.indexOf('error=')) { window.opener.POP2.receiveToken('ERROR'); } }