Save client tokens for userinfo requests

This commit is contained in:
Melon 2024-02-10 11:59:45 +00:00
parent 0c668253a8
commit 05b19e6bf2
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
6 changed files with 89 additions and 12 deletions

View File

@ -4,6 +4,9 @@ CREATE TABLE IF NOT EXISTS users
email TEXT NOT NULL, email TEXT NOT NULL,
email_verified INTEGER DEFAULT 0 NOT NULL, email_verified INTEGER DEFAULT 0 NOT NULL,
roles TEXT NOT NULL, roles TEXT NOT NULL,
access_token TEXT,
refresh_token TEXT,
expiry DATETIME,
updated_at DATETIME, updated_at DATETIME,
active INTEGER DEFAULT 1 active INTEGER DEFAULT 1
); );

View File

@ -138,6 +138,16 @@ func (t *Tx) UpdateUser(subject, roles string, active bool) error {
return err return err
} }
func (t *Tx) UpdateUserToken(subject, accessToken, refreshToken string, expiry time.Time) error {
_, err := t.tx.Exec(`UPDATE users SET access_token = ?, refresh_token = ?, expiry = ? WHERE subject = ?`, accessToken, refreshToken, expiry, subject)
return err
}
func (t *Tx) GetUserToken(subject string, accessToken, refreshToken *string, expiry *time.Time) error {
row := t.tx.QueryRow(`SELECT access_token, refresh_token, expiry FROM users WHERE subject = ? LIMIT 1`, subject)
return row.Scan(accessToken, refreshToken, expiry)
}
func (t *Tx) UserEmailExists(email string) (exists bool, err error) { func (t *Tx) UserEmailExists(email string) (exists bool, err error) {
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email) row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email)
err = row.Scan(&exists) err = row.Scan(&exists)

View File

@ -10,6 +10,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/1f349/lavender/database" "github.com/1f349/lavender/database"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/pages" "github.com/1f349/lavender/pages"
@ -108,7 +109,7 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _
return return
} }
sessionData := h.fetchUserInfo(rw, err, flowState.sso, token) sessionData, err := h.fetchUserInfo(flowState.sso, token)
if sessionData.ID == "" { if sessionData.ID == "" {
http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError) http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError)
return return
@ -133,6 +134,12 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _
return return
} }
if h.DbTx(rw, func(tx *database.Tx) error {
return tx.UpdateUserToken(auth.Data.ID, token.AccessToken, token.RefreshToken, token.Expiry)
}) {
return
}
if h.setLoginDataCookie(rw, auth.Data.ID, token) { if h.setLoginDataCookie(rw, auth.Data.ID, token) {
http.Error(rw, "Internal Server Error", http.StatusInternalServerError) http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return return
@ -199,25 +206,27 @@ func (h *HttpServer) readLoginDataCookie(rw http.ResponseWriter, req *http.Reque
return return
} }
u.Data = h.fetchUserInfo(rw, err, sso, token) u.Data, err = h.fetchUserInfo(sso, token)
if err != nil {
http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError)
return
}
} }
func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issuer.WellKnownOIDC, token *oauth2.Token) SessionData { func (h *HttpServer) fetchUserInfo(sso *issuer.WellKnownOIDC, token *oauth2.Token) (SessionData, error) {
res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint) res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint)
if err != nil || res.StatusCode != http.StatusOK { if err != nil || res.StatusCode != http.StatusOK {
return SessionData{} return SessionData{}, fmt.Errorf("request failed")
} }
defer res.Body.Close() defer res.Body.Close()
var userInfoJson UserInfoFields var userInfoJson UserInfoFields
if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil { if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) return SessionData{}, err
return SessionData{}
} }
subject, ok := userInfoJson.GetString("sub") subject, ok := userInfoJson.GetString("sub")
if !ok { if !ok {
http.Error(rw, "Invalid subject", http.StatusInternalServerError) return SessionData{}, fmt.Errorf("invalid subject")
return SessionData{}
} }
subject += "@" + sso.Config.Namespace subject += "@" + sso.Config.Namespace
@ -226,5 +235,5 @@ func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issue
ID: subject, ID: subject,
DisplayName: displayName, DisplayName: displayName,
UserInfo: userInfoJson, UserInfo: userInfoJson,
} }, nil
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/go-oauth2/oauth2/v4/store" "github.com/go-oauth2/oauth2/v4/store"
"github.com/go-session/session" "github.com/go-session/session"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
oauth22 "golang.org/x/oauth2"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
@ -186,6 +187,26 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
} }
userId := token.GetUserID() userId := token.GetUserID()
log.Println(userId)
sso := hs.manager.FindServiceFromLogin(userId)
if sso == nil {
http.Error(rw, "Invalid user", http.StatusBadRequest)
return
}
var clientToken oauth22.Token
if hs.DbTx(rw, func(tx *database.Tx) error {
return tx.GetUserToken(userId, &clientToken.AccessToken, &clientToken.RefreshToken, &clientToken.Expiry)
}) {
return
}
info, err := hs.fetchUserInfo(sso, &clientToken)
if err != nil {
http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError)
return
}
fmt.Printf("Using token for user: %s by app: %s with scope: '%s'\n", userId, token.GetClientID(), token.GetScope()) fmt.Printf("Using token for user: %s by app: %s with scope: '%s'\n", userId, token.GetClientID(), token.GetScope())
claims := ParseClaims(token.GetScope()) claims := ParseClaims(token.GetScope())
if !claims["openid"] { if !claims["openid"] {
@ -193,7 +214,36 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
return return
} }
m := map[string]any{} m := make(map[string]any)
if claims["name"] {
m["name"] = info.UserInfo["name"]
}
if claims["username"] {
m["preferred_username"] = info.UserInfo["preferred_username"]
}
if claims["profile"] {
m["profile"] = info.UserInfo["profile"]
m["picture"] = info.UserInfo["picture"]
m["website"] = info.UserInfo["website"]
}
if claims["email"] {
m["email"] = info.UserInfo["email"]
m["email_verified"] = info.UserInfo["email_verified"]
}
if claims["birthdate"] {
m["birthdate"] = info.UserInfo["birthdate"]
}
if claims["age"] {
m["age"] = info.UserInfo["age"]
}
if claims["zoneinfo"] {
m["zoneinfo"] = info.UserInfo["zoneinfo"]
}
if claims["locale"] {
m["locale"] = info.UserInfo["locale"]
}
m["sub"] = userId m["sub"] = userId
m["aud"] = token.GetClientID() m["aud"] = token.GetClientID()
m["updated_at"] = time.Now().Unix() m["updated_at"] = time.Now().Unix()

View File

@ -8,13 +8,18 @@ func (u UserInfoFields) GetString(key string) (string, bool) {
} }
func (u UserInfoFields) GetStringOrDefault(key, other string) string { func (u UserInfoFields) GetStringOrDefault(key, other string) string {
s, ok := u.GetString(key) s, ok := u[key].(string)
if !ok { if !ok {
s = other s = other
} }
return s return s
} }
func (u UserInfoFields) GetStringOrEmpty(key string) string {
s, _ := u[key].(string)
return s
}
func (u UserInfoFields) GetBoolean(key string) (bool, bool) { func (u UserInfoFields) GetBoolean(key string) (bool, bool) {
b, ok := u[key].(bool) b, ok := u[key].(bool)
return b, ok return b, ok

View File

@ -6,7 +6,7 @@
<script> <script>
const ssoService = "http://localhost:9090"; const ssoService = "http://localhost:9090";
POP2.init(ssoService + "/authorize", "f4cdb93d-fe28-427b-b037-f03f44c86a16", "openid profile", 500, 600); POP2.init(ssoService + "/authorize", "f4cdb93d-fe28-427b-b037-f03f44c86a16", "openid profile age", 500, 600);
function updateTokenInfo(data) { function updateTokenInfo(data) {
document.getElementById("someTextArea").textContent = JSON.stringify(data, null, 2); document.getElementById("someTextArea").textContent = JSON.stringify(data, null, 2);