From 182c424b332b8932d8fa0069adae949448a5030a Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Tue, 5 Dec 2023 18:10:47 +0000 Subject: [PATCH] Start implementing refresh tokens --- go.mod | 1 + go.sum | 2 + server/flow.go | 105 ++--------------------- server/refresh.go | 183 +++++++++++++++++++++++++++++++++++++++++ server/server.go | 20 +++++ test-client/index.html | 38 ++++++++- 6 files changed, 249 insertions(+), 100 deletions(-) create mode 100644 server/refresh.go diff --git a/go.mod b/go.mod index 6e942d2..719f216 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/google/subcommands v1.2.0 github.com/google/uuid v1.4.0 github.com/julienschmidt/httprouter v1.3.0 + github.com/rs/cors v1.10.1 github.com/stretchr/testify v1.8.4 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d ) diff --git a/go.sum b/go.sum index b92086c..3356947 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= +github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/server/flow.go b/server/flow.go index c1c6ae2..8030f06 100644 --- a/server/flow.go +++ b/server/flow.go @@ -3,18 +3,13 @@ package server import ( "context" _ "embed" - "encoding/json" "fmt" "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/server/pages" - "github.com/1f349/mjwt/auth" - "github.com/1f349/mjwt/claims" - "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "golang.org/x/oauth2" "net/http" - "net/mail" "net/url" "strings" "time" @@ -138,98 +133,14 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h http.Error(rw, "Failed to exchange code", http.StatusInternalServerError) return } - v2, err := testOa2UserInfo(v.sso, req.Context(), exchange) - if err != nil { - fmt.Println("Failed to get userinfo:", err) - http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) - return - } - defer v2.Body.Close() - if v2.StatusCode != http.StatusOK { - http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError) - return - } - var v3 map[string]any - if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil { - fmt.Println("Failed to decode userinfo:", err) - http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError) - return - } - - sub, ok := v3["sub"].(string) - if !ok { - http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError) - return - } - aud, ok := v3["aud"].(string) - if !ok { - http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError) - return - } - - var needsMailFlag, needsDomains bool - - ps := claims.NewPermStorage() - for _, i := range v.target.Permissions { - if strings.HasPrefix(i, "dynamic:") { - switch i { - case "dynamic:mail-inbox": - needsMailFlag = true - case "dynamic:domain-owns": - needsDomains = true - } - } else { - ps.Set(i) - } - } - - if needsMailFlag { - if verified, ok := v3["email_verified"].(bool); ok && verified { - if mailAddress, ok := v3["email"].(string); ok { - address, err := mail.ParseAddress(mailAddress) - if err != nil { - http.Error(rw, "Invalid email in userinfo", http.StatusInternalServerError) - return - } - n := strings.IndexByte(address.Address, '@') - if n != -1 { - if address.Address[n+1:] == v.sso.Config.Namespace { - ps.Set("mail:inbox=" + address.Address) - } - } - } - } - } - - if needsDomains { - a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace) - for _, i := range a { - ps.Set("domain:owns=" + i) - } - } - - nsSub := sub + "@" + v.sso.Config.Namespace - ati := uuidNewStringAti() - accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{ - Perms: ps, - }) - if err != nil { - http.Error(rw, "Error generating access token", http.StatusInternalServerError) - return - } - - refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati}) - if err != nil { - http.Error(rw, "Error generating refresh token", http.StatusInternalServerError) - return - } - - pages.RenderPageTemplate(rw, "flow-callback", map[string]any{ - "ServiceName": h.conf.Load().ServiceName, - "TargetOrigin": v.target.Url.String(), - "TargetMessage": v3, - "AccessToken": accessToken, - "RefreshToken": refreshToken, + h.finishTokenGenerateFlow(rw, req, v, exchange, func(accessToken, refreshToken string, v3 map[string]any) { + pages.RenderPageTemplate(rw, "flow-callback", map[string]any{ + "ServiceName": h.conf.Load().ServiceName, + "TargetOrigin": v.target.Url.String(), + "TargetMessage": v3, + "AccessToken": accessToken, + "RefreshToken": refreshToken, + }) }) } diff --git a/server/refresh.go b/server/refresh.go new file mode 100644 index 0000000..f8231ec --- /dev/null +++ b/server/refresh.go @@ -0,0 +1,183 @@ +package server + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/1f349/mjwt/auth" + "github.com/1f349/mjwt/claims" + "github.com/golang-jwt/jwt/v4" + "github.com/julienschmidt/httprouter" + "golang.org/x/oauth2" + "net/http" + "net/mail" + "strings" + "time" +) + +func (h *HttpServer) refreshHandler(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + ref := strings.TrimSuffix(req.Referer(), "/") + allowedClient, ok := (*h.services.Load())[ref] + if !ok { + http.Error(rw, "Invalid origin", http.StatusBadRequest) + return + } + loginNameCookie, err := req.Cookie("lavender-login-name") + if err != nil { + http.Error(rw, "Failed to read cookie", http.StatusBadRequest) + return + } + loginService := h.manager.Load().FindServiceFromLogin(loginNameCookie.Value) + cookie, err := req.Cookie("sso-exchange") + if err != nil { + http.Error(rw, "Failed to read cookie", http.StatusBadRequest) + return + } + rawEncrypt, err := base64.RawURLEncoding.DecodeString(cookie.Value) + if err != nil { + http.Error(rw, "Internal Server Error", http.StatusBadRequest) + return + } + rawTokens, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, h.signer.PrivateKey(), rawEncrypt, []byte("sso-exchange")) + if err != nil { + http.Error(rw, "Internal Server Error", http.StatusBadRequest) + return + } + var exchange oauth2.Token + err = json.Unmarshal(rawTokens, &exchange) + if err != nil { + http.Error(rw, "Internal Server Error", http.StatusBadRequest) + return + } + h.finishTokenGenerateFlow(rw, req, flowStateData{ + sso: loginService, + target: allowedClient, + }, &exchange, func(accessToken string, refreshToken string, v3 map[string]any) { + tokens := map[string]any{ + "target": allowedClient.Url.String(), + "userinfo": v3, + "tokens": map[string]any{ + "access": accessToken, + "refresh": refreshToken, + }, + } + _ = json.NewEncoder(rw).Encode(tokens) + }) +} + +func (h *HttpServer) finishTokenGenerateFlow(rw http.ResponseWriter, req *http.Request, v flowStateData, exchange *oauth2.Token, response func(accessToken string, refreshToken string, v3 map[string]any)) { + // fetch user info + v2, err := testOa2UserInfo(v.sso, req.Context(), exchange) + if err != nil { + fmt.Println("Failed to get userinfo:", err) + http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) + return + } + defer v2.Body.Close() + if v2.StatusCode != http.StatusOK { + http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError) + return + } + + // encrypt exchange tokens for cookie storage + marshal, err := json.Marshal(exchange) + if err != nil { + fmt.Println("Failed to marshal exchange tokens", err) + http.Error(rw, "Internal server error", http.StatusInternalServerError) + return + } + oaepBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, h.signer.PublicKey(), marshal, []byte("sso-exchange")) + if err != nil { + fmt.Println("Failed to encrypt exchange tokens", err) + http.Error(rw, "Internal server error", http.StatusInternalServerError) + return + } + http.SetCookie(rw, &http.Cookie{ + Name: "sso-exchange", + Value: base64.RawURLEncoding.EncodeToString(oaepBytes), + Path: "/", + Expires: time.Now().AddDate(0, 3, 0), + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + var v3 map[string]any + if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil { + fmt.Println("Failed to decode userinfo:", err) + http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError) + return + } + + sub, ok := v3["sub"].(string) + if !ok { + http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError) + return + } + aud, ok := v3["aud"].(string) + if !ok { + http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError) + return + } + + var needsMailFlag, needsDomains bool + + ps := claims.NewPermStorage() + for _, i := range v.target.Permissions { + if strings.HasPrefix(i, "dynamic:") { + switch i { + case "dynamic:mail-inbox": + needsMailFlag = true + case "dynamic:domain-owns": + needsDomains = true + } + } else { + ps.Set(i) + } + } + + if needsMailFlag { + if verified, ok := v3["email_verified"].(bool); ok && verified { + if mailAddress, ok := v3["email"].(string); ok { + address, err := mail.ParseAddress(mailAddress) + if err != nil { + http.Error(rw, "Invalid email in userinfo", http.StatusInternalServerError) + return + } + n := strings.IndexByte(address.Address, '@') + if n != -1 { + if address.Address[n+1:] == v.sso.Config.Namespace { + ps.Set("mail:inbox=" + address.Address) + } + } + } + } + } + + if needsDomains { + a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace) + for _, i := range a { + ps.Set("domain:owns=" + i) + } + } + + nsSub := sub + "@" + v.sso.Config.Namespace + ati := uuidNewStringAti() + accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{ + Perms: ps, + }) + if err != nil { + http.Error(rw, "Error generating access token", http.StatusInternalServerError) + return + } + + refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati}) + if err != nil { + http.Error(rw, "Error generating refresh token", http.StatusInternalServerError) + return + } + + response(accessToken, refreshToken, v3) +} diff --git a/server/server.go b/server/server.go index 0a6a9f8..749d75f 100644 --- a/server/server.go +++ b/server/server.go @@ -6,6 +6,7 @@ import ( "github.com/1f349/lavender/issuer" "github.com/1f349/mjwt" "github.com/julienschmidt/httprouter" + "github.com/rs/cors" "log" "net/http" "sync/atomic" @@ -66,6 +67,25 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *HttpServer { r.GET("/popup", hs.flowPopup) r.POST("/popup", hs.flowPopupPost) r.GET("/callback", hs.flowCallback) + + var corsAccessControl = cors.New(cors.Options{ + AllowOriginFunc: func(origin string) bool { + load := hs.services.Load() + _, ok := (*load)[origin] + return ok + }, + AllowedMethods: []string{http.MethodPost, http.MethodOptions}, + AllowedHeaders: []string{"Content-Type"}, + AllowCredentials: true, + }) + r.POST("/refresh", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + corsAccessControl.ServeHTTP(rw, req, func(writer http.ResponseWriter, request *http.Request) { + hs.refreshHandler(rw, req, params) + }) + }) + r.OPTIONS("/refresh", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + corsAccessControl.ServeHTTP(rw, req, func(_ http.ResponseWriter, _ *http.Request) {}) + }) return hs } diff --git a/test-client/index.html b/test-client/index.html index 3af37e8..f2226b2 100644 --- a/test-client/index.html +++ b/test-client/index.html @@ -4,15 +4,20 @@ Test Client @@ -88,10 +114,16 @@
+
- +
+ +
+
+ +

Permissions: