diff --git a/database/init.sql b/database/init.sql index 7271513..083ecde 100644 --- a/database/init.sql +++ b/database/init.sql @@ -4,6 +4,9 @@ CREATE TABLE IF NOT EXISTS users email TEXT NOT NULL, email_verified INTEGER DEFAULT 0 NOT NULL, roles TEXT NOT NULL, + access_token TEXT, + refresh_token TEXT, + expiry DATETIME, updated_at DATETIME, active INTEGER DEFAULT 1 ); diff --git a/database/tx.go b/database/tx.go index 6b85021..b7f2a94 100644 --- a/database/tx.go +++ b/database/tx.go @@ -138,6 +138,16 @@ func (t *Tx) UpdateUser(subject, roles string, active bool) error { return err } +func (t *Tx) UpdateUserToken(subject, accessToken, refreshToken string, expiry time.Time) error { + _, err := t.tx.Exec(`UPDATE users SET access_token = ?, refresh_token = ?, expiry = ? WHERE subject = ?`, accessToken, refreshToken, expiry, subject) + return err +} + +func (t *Tx) GetUserToken(subject string, accessToken, refreshToken *string, expiry *time.Time) error { + row := t.tx.QueryRow(`SELECT access_token, refresh_token, expiry FROM users WHERE subject = ? LIMIT 1`, subject) + return row.Scan(accessToken, refreshToken, expiry) +} + func (t *Tx) UserEmailExists(email string) (exists bool, err error) { row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email) err = row.Scan(&exists) diff --git a/server/login.go b/server/login.go index 5d5cd68..e950fff 100644 --- a/server/login.go +++ b/server/login.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "github.com/1f349/lavender/database" "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/pages" @@ -108,7 +109,7 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ return } - sessionData := h.fetchUserInfo(rw, err, flowState.sso, token) + sessionData, err := h.fetchUserInfo(flowState.sso, token) if sessionData.ID == "" { http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError) return @@ -133,6 +134,12 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ return } + if h.DbTx(rw, func(tx *database.Tx) error { + return tx.UpdateUserToken(auth.Data.ID, token.AccessToken, token.RefreshToken, token.Expiry) + }) { + return + } + if h.setLoginDataCookie(rw, auth.Data.ID, token) { http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return @@ -199,25 +206,27 @@ func (h *HttpServer) readLoginDataCookie(rw http.ResponseWriter, req *http.Reque return } - u.Data = h.fetchUserInfo(rw, err, sso, token) + u.Data, err = h.fetchUserInfo(sso, token) + if err != nil { + http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError) + return + } } -func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issuer.WellKnownOIDC, token *oauth2.Token) SessionData { +func (h *HttpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (SessionData, error) { res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint) if err != nil || res.StatusCode != http.StatusOK { - return SessionData{} + return SessionData{}, fmt.Errorf("request failed") } defer res.Body.Close() var userInfoJson UserInfoFields if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return SessionData{} + return SessionData{}, err } subject, ok := userInfoJson.GetString("sub") if !ok { - http.Error(rw, "Invalid subject", http.StatusInternalServerError) - return SessionData{} + return SessionData{}, fmt.Errorf("invalid subject") } subject += "@" + sso.Config.Namespace @@ -226,5 +235,5 @@ func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issue ID: subject, DisplayName: displayName, UserInfo: userInfoJson, - } + }, nil } diff --git a/server/server.go b/server/server.go index 75bf0f2..450977a 100644 --- a/server/server.go +++ b/server/server.go @@ -20,6 +20,7 @@ import ( "github.com/go-oauth2/oauth2/v4/store" "github.com/go-session/session" "github.com/julienschmidt/httprouter" + oauth22 "golang.org/x/oauth2" "log" "net/http" "net/url" @@ -186,6 +187,26 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser } userId := token.GetUserID() + log.Println(userId) + sso := hs.manager.FindServiceFromLogin(userId) + if sso == nil { + http.Error(rw, "Invalid user", http.StatusBadRequest) + return + } + + var clientToken oauth22.Token + if hs.DbTx(rw, func(tx *database.Tx) error { + return tx.GetUserToken(userId, &clientToken.AccessToken, &clientToken.RefreshToken, &clientToken.Expiry) + }) { + return + } + + info, err := hs.fetchUserInfo(sso, &clientToken) + if err != nil { + http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError) + return + } + fmt.Printf("Using token for user: %s by app: %s with scope: '%s'\n", userId, token.GetClientID(), token.GetScope()) claims := ParseClaims(token.GetScope()) if !claims["openid"] { @@ -193,7 +214,36 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser return } - m := map[string]any{} + m := make(map[string]any) + + if claims["name"] { + m["name"] = info.UserInfo["name"] + } + if claims["username"] { + m["preferred_username"] = info.UserInfo["preferred_username"] + } + if claims["profile"] { + m["profile"] = info.UserInfo["profile"] + m["picture"] = info.UserInfo["picture"] + m["website"] = info.UserInfo["website"] + } + if claims["email"] { + m["email"] = info.UserInfo["email"] + m["email_verified"] = info.UserInfo["email_verified"] + } + if claims["birthdate"] { + m["birthdate"] = info.UserInfo["birthdate"] + } + if claims["age"] { + m["age"] = info.UserInfo["age"] + } + if claims["zoneinfo"] { + m["zoneinfo"] = info.UserInfo["zoneinfo"] + } + if claims["locale"] { + m["locale"] = info.UserInfo["locale"] + } + m["sub"] = userId m["aud"] = token.GetClientID() m["updated_at"] = time.Now().Unix() diff --git a/server/userinfofields.go b/server/userinfofields.go index e40ed49..5b28e7c 100644 --- a/server/userinfofields.go +++ b/server/userinfofields.go @@ -8,13 +8,18 @@ func (u UserInfoFields) GetString(key string) (string, bool) { } func (u UserInfoFields) GetStringOrDefault(key, other string) string { - s, ok := u.GetString(key) + s, ok := u[key].(string) if !ok { s = other } return s } +func (u UserInfoFields) GetStringOrEmpty(key string) string { + s, _ := u[key].(string) + return s +} + func (u UserInfoFields) GetBoolean(key string) (bool, bool) { b, ok := u[key].(bool) return b, ok diff --git a/test-client/index.html b/test-client/index.html index 955835b..8969f4d 100644 --- a/test-client/index.html +++ b/test-client/index.html @@ -6,7 +6,7 @@