diff --git a/database/db-types.go b/database/db-types.go index 6ce6c53..461bbfa 100644 --- a/database/db-types.go +++ b/database/db-types.go @@ -10,6 +10,7 @@ type User struct { Email string `json:"email"` EmailVerified bool `json:"email_verified"` Roles string `json:"roles"` + UserInfo string `json:"userinfo"` UpdatedAt time.Time `json:"updated_at"` Active bool `json:"active"` } diff --git a/database/init.sql b/database/init.sql index 2be31a7..0009f1d 100644 --- a/database/init.sql +++ b/database/init.sql @@ -4,6 +4,7 @@ CREATE TABLE IF NOT EXISTS users email TEXT NOT NULL, email_verified INTEGER DEFAULT 0 NOT NULL, roles TEXT NOT NULL, + userinfo TEXT, access_token TEXT, refresh_token TEXT, expiry DATETIME, diff --git a/database/tx.go b/database/tx.go index 68cf6d0..86a4abc 100644 --- a/database/tx.go +++ b/database/tx.go @@ -37,8 +37,13 @@ func (t *Tx) HasUser() error { return nil } -func (t *Tx) InsertUser(subject, email string, verifyEmail bool, roles string, active bool) error { - _, err := t.tx.Exec(`INSERT INTO users (subject, email, email_verified, roles, updated_at, active) VALUES (?, ?, ?, ?, ?, ?)`, subject, email, verifyEmail, roles, updatedAt(), active) +func (t *Tx) InsertUser(subject, email string, verifyEmail bool, roles, userinfo string, active bool) error { + _, err := t.tx.Exec(`INSERT INTO users (subject, email, email_verified, roles, userinfo, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?)`, subject, email, verifyEmail, roles, userinfo, updatedAt(), active) + return err +} + +func (t *Tx) UpdateUserInfo(subject, email string, verified bool, userinfo string) error { + _, err := t.tx.Exec(`UPDATE users SET email = ?, email_verified = ?, userinfo = ? WHERE subject = ?`, email, verified, userinfo, subject) return err } @@ -51,8 +56,8 @@ func (t *Tx) GetUserRoles(sub string) (string, error) { func (t *Tx) GetUser(sub string) (*User, error) { var u User - row := t.tx.QueryRow(`SELECT email, email_verified, roles, updated_at, active FROM users WHERE subject = ?`, sub) - err := row.Scan(&u.Email, &u.EmailVerified, &u.Roles, &u.UpdatedAt, &u.Active) + row := t.tx.QueryRow(`SELECT email, email_verified, roles, userifo, updated_at, active FROM users WHERE subject = ?`, sub) + err := row.Scan(&u.Email, &u.EmailVerified, &u.Roles, &u.UserInfo, &u.UpdatedAt, &u.Active) u.Sub = sub return &u, err } diff --git a/issuer/manager.go b/issuer/manager.go index a5a6ff5..b37b8b1 100644 --- a/issuer/manager.go +++ b/issuer/manager.go @@ -30,17 +30,6 @@ func NewManager(services []SsoConfig) (*Manager, error) { return l, nil } -func NewManagerForTests(services []*WellKnownOIDC) *Manager { - l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))} - for _, i := range services { - if !isValidNamespace.MatchString(i.Config.Namespace) { - panic("Invalid namespace in tests") - } - l.m[i.Config.Namespace] = i - } - return l -} - func (m *Manager) CheckNamespace(namespace string) bool { _, ok := m.m[namespace] return ok diff --git a/server/login.go b/server/login.go index 75d2552..63496bb 100644 --- a/server/login.go +++ b/server/login.go @@ -118,13 +118,19 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ } if h.DbTx(rw, func(tx *database.Tx) error { - _, err := tx.GetUser(sessionData.ID) + jBytes, err := json.Marshal(sessionData.UserInfo) + if err != nil { + return err + } + _, err = tx.GetUser(sessionData.ID) if errors.Is(err, sql.ErrNoRows) { uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") - return tx.InsertUser(sessionData.ID, uEmail, uEmailVerified, "", true) + return tx.InsertUser(sessionData.ID, uEmail, uEmailVerified, "", string(jBytes), true) } - return err + uEmail := sessionData.UserInfo.GetStringOrDefault("email", "unknown@localhost") + uEmailVerified, _ := sessionData.UserInfo.GetBoolean("email_verified") + return tx.UpdateUserInfo(sessionData.ID, uEmail, uEmailVerified, string(jBytes)) }) { return } diff --git a/server/server.go b/server/server.go index beb8674..f38ef2b 100644 --- a/server/server.go +++ b/server/server.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/subtle" "encoding/json" - "fmt" "github.com/1f349/cache" clientStore "github.com/1f349/lavender/client-store" "github.com/1f349/lavender/database" @@ -18,7 +17,6 @@ import ( "github.com/go-oauth2/oauth2/v4/server" "github.com/go-oauth2/oauth2/v4/store" "github.com/julienschmidt/httprouter" - "golang.org/x/oauth2" "log" "net/http" "net/url" @@ -185,20 +183,21 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser return } - var clientToken oauth2.Token - if hs.DbTx(rw, func(tx *database.Tx) error { - return tx.GetUserToken(userId, &clientToken.AccessToken, &clientToken.RefreshToken, &clientToken.Expiry) + var user *database.User + if hs.DbTx(rw, func(tx *database.Tx) (err error) { + user, err = tx.GetUser(userId) + return err }) { return } - info, err := hs.fetchUserInfo(sso, &clientToken) + var userInfo UserInfoFields + err = json.Unmarshal([]byte(user.UserInfo), &userInfo) if err != nil { - http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError) + http.Error(rw, "500 Internal Server Error", http.StatusInternalServerError) return } - fmt.Printf("Using token for user: %s by app: %s with scope: '%s'\n", userId, token.GetClientID(), token.GetScope()) claims := ParseClaims(token.GetScope()) if !claims["openid"] { http.Error(rw, "Invalid scope", http.StatusBadRequest) @@ -208,32 +207,32 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser m := make(map[string]any) if claims["name"] { - m["name"] = info.UserInfo["name"] + m["name"] = userInfo["name"] } if claims["username"] { - m["preferred_username"] = info.UserInfo["preferred_username"] - m["login"] = info.UserInfo["login"] + m["preferred_username"] = userInfo["preferred_username"] + m["login"] = userInfo["login"] } if claims["profile"] { - m["profile"] = info.UserInfo["profile"] - m["picture"] = info.UserInfo["picture"] - m["website"] = info.UserInfo["website"] + m["profile"] = userInfo["profile"] + m["picture"] = userInfo["picture"] + m["website"] = userInfo["website"] } if claims["email"] { - m["email"] = info.UserInfo["email"] - m["email_verified"] = info.UserInfo["email_verified"] + m["email"] = userInfo["email"] + m["email_verified"] = userInfo["email_verified"] } if claims["birthdate"] { - m["birthdate"] = info.UserInfo["birthdate"] + m["birthdate"] = userInfo["birthdate"] } if claims["age"] { - m["age"] = info.UserInfo["age"] + m["age"] = userInfo["age"] } if claims["zoneinfo"] { - m["zoneinfo"] = info.UserInfo["zoneinfo"] + m["zoneinfo"] = userInfo["zoneinfo"] } if claims["locale"] { - m["locale"] = info.UserInfo["locale"] + m["locale"] = userInfo["locale"] } m["sub"] = userId