diff --git a/go.mod b/go.mod index 375cfd7..5e74698 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/google/uuid v1.6.0 github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.22 + github.com/rs/cors v1.10.1 github.com/stretchr/testify v1.8.4 golang.org/x/oauth2 v0.16.0 ) @@ -35,7 +36,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/rtred v0.1.2 // indirect github.com/tidwall/tinyqueue v0.1.1 // indirect - golang.org/x/net v0.20.0 // indirect + golang.org/x/net v0.21.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/protobuf v1.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index a0da316..14871f0 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= +github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -177,8 +179,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= diff --git a/pages/login-memory.go.html b/pages/login-memory.go.html index 9d9f143..a566467 100644 --- a/pages/login-memory.go.html +++ b/pages/login-memory.go.html @@ -12,13 +12,11 @@
Log in as: {{.LoginName}}
-
-
diff --git a/pages/login.go.html b/pages/login.go.html index f19eac4..99380de 100644 --- a/pages/login.go.html +++ b/pages/login.go.html @@ -10,7 +10,6 @@
-
diff --git a/server/auth.go b/server/auth.go index 8d0f106..d85ffb5 100644 --- a/server/auth.go +++ b/server/auth.go @@ -67,8 +67,9 @@ func (h *HttpServer) OptionalAuthentication(next UserHandler) httprouter.Handle http.Error(rw, err.Error(), http.StatusInternalServerError) return } - if auth.IsGuest() && h.readLoginDataCookie(rw, req, &auth) { - return + if auth.IsGuest() { + // if this fails internally it just sees the user as logged out + h.readLoginDataCookie(rw, req, &auth) } next(rw, req, params, auth) } diff --git a/server/login.go b/server/login.go index 2d5cca2..653c85d 100644 --- a/server/login.go +++ b/server/login.go @@ -32,14 +32,12 @@ func (h *HttpServer) loginGet(rw http.ResponseWriter, req *http.Request, _ httpr if err == nil && cookie.Valid() == nil { pages.RenderPageTemplate(rw, "login-memory", map[string]any{ "ServiceName": h.conf.ServiceName, - "Origin": req.URL.Query().Get("origin"), "LoginName": cookie.Value, }) return } pages.RenderPageTemplate(rw, "login", map[string]any{ "ServiceName": h.conf.ServiceName, - "Origin": req.URL.Query().Get("origin"), }) } @@ -60,9 +58,6 @@ func (h *HttpServer) loginPost(rw http.ResponseWriter, req *http.Request, _ http }) http.Redirect(rw, req, (&url.URL{ Path: "/login", - RawQuery: url.Values{ - "origin": []string{req.PostFormValue("origin")}, - }.Encode(), }).String(), http.StatusFound) return } @@ -111,8 +106,9 @@ func (h *HttpServer) loginCallback(rw http.ResponseWriter, req *http.Request, _ return } - sessionData, done := h.fetchUserInfo(rw, err, flowState.sso, token) - if !done { + sessionData := h.fetchUserInfo(rw, err, flowState.sso, token) + if sessionData.ID == "" { + http.Error(rw, "Failed to fetch user info", http.StatusInternalServerError) return } @@ -166,63 +162,57 @@ func (h *HttpServer) setLoginDataCookie(rw http.ResponseWriter, userId string, t return false } -func (h *HttpServer) readLoginDataCookie(rw http.ResponseWriter, req *http.Request, u *UserAuth) bool { +func (h *HttpServer) readLoginDataCookie(rw http.ResponseWriter, req *http.Request, u *UserAuth) { loginCookie, err := req.Cookie("lavender-login-data") if err != nil { - return false + return } decryptedBytes, err := base64.RawStdEncoding.DecodeString(loginCookie.Value) if err != nil { - return false + return } decryptedData, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, h.signingKey.PrivateKey(), decryptedBytes, []byte("lavender-login-data")) if err != nil { - return false + return } buf := bytes.NewBuffer(decryptedData) userId, err := buf.ReadString(0) if err != nil { - return false + return } userId = strings.TrimSuffix(userId, "\x00") var token *oauth2.Token err = json.NewDecoder(buf).Decode(&token) if err != nil { - return false + return } sso := h.manager.FindServiceFromLogin(userId) if sso == nil { - return false + return } - sessionData, done := h.fetchUserInfo(rw, err, sso, token) - if !done { - return true - } - - u.Data = sessionData - return false + u.Data = h.fetchUserInfo(rw, err, sso, token) } -func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issuer.WellKnownOIDC, token *oauth2.Token) (SessionData, bool) { +func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issuer.WellKnownOIDC, token *oauth2.Token) SessionData { res, err := sso.OAuth2Config.Client(context.Background(), token).Get(sso.UserInfoEndpoint) if err != nil || res.StatusCode != http.StatusOK { - return SessionData{}, false + return SessionData{} } defer res.Body.Close() var userInfoJson UserInfoFields if err := json.NewDecoder(res.Body).Decode(&userInfoJson); err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) - return SessionData{}, false + return SessionData{} } subject, ok := userInfoJson.GetString("sub") if !ok { http.Error(rw, "Invalid subject", http.StatusInternalServerError) - return SessionData{}, false + return SessionData{} } subject += "@" + sso.Config.Namespace @@ -231,5 +221,5 @@ func (h *HttpServer) fetchUserInfo(rw http.ResponseWriter, err error, sso *issue ID: subject, DisplayName: displayName, UserInfo: userInfoJson, - }, true + } } diff --git a/server/server.go b/server/server.go index 7d466db..08f2b8e 100644 --- a/server/server.go +++ b/server/server.go @@ -169,7 +169,15 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser http.Error(rw, err.Error(), http.StatusInternalServerError) } }) - r.GET("/userinfo", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + userInfoRequest := func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { + rw.Header().Set("Access-Control-Allow-Credentials", "true") + rw.Header().Set("Access-Control-Allow-Headers", "Authorization,Content-Type") + rw.Header().Set("Access-Control-Allow-Origin", strings.TrimSuffix(req.Referer(), "/")) + rw.Header().Set("Access-Control-Allow-Methods", "GET") + if req.Method == http.MethodOptions { + return + } + token, err := oauthSrv.ValidationBearerToken(req) if err != nil { http.Error(rw, "403 Forbidden", http.StatusForbidden) @@ -190,7 +198,9 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser m["updated_at"] = time.Now().Unix() _ = json.NewEncoder(rw).Encode(m) - }) + } + r.GET("/userinfo", userInfoRequest) + r.OPTIONS("/userinfo", userInfoRequest) return &http.Server{ Addr: conf.Listen, diff --git a/server/userinfo.go b/server/userinfofields.go similarity index 100% rename from server/userinfo.go rename to server/userinfofields.go diff --git a/test-client/index.html b/test-client/index.html index 26b9261..adc7413 100644 --- a/test-client/index.html +++ b/test-client/index.html @@ -1,92 +1,72 @@ - Test Client - - + + + #tokenValues { + width: 400px; + height: 400px; + } +
-

Test Client

+

Test Client

+
+ +
+
- +
+ +
+
+ +
-
-
-
- -
-
- -
-
-
-

Permissions:

-
    -
    +
    +

    Permissions:

    +
      +
      diff --git a/test-client/popup.js b/test-client/pop2.js similarity index 80% rename from test-client/popup.js rename to test-client/pop2.js index e1d8e0c..f1c30d1 100644 --- a/test-client/popup.js +++ b/test-client/pop2.js @@ -71,7 +71,7 @@ client_id, scope = '', redirect_uri = window.location.href.substr(0, window.location.href.length - window.location.hash.length).replace(/#$/, ''), - access_token, + access_token = localStorage.getItem("pop2_access_token"), callbackWaitForToken, w_width = 400, w_height = 360; @@ -91,10 +91,12 @@ receiveToken: function (token, expires_in) { if (token !== 'ERROR') { access_token = token; + localStorage.setItem("pop2_access_token", access_token); if (callbackWaitForToken) callbackWaitForToken(access_token); setTimeout( function () { access_token = undefined; + localStorage.removeItem("pop2_access_token"); }, expires_in * 1000 ); @@ -129,6 +131,31 @@ } else { return callback(access_token); } + }, + clientRequest: function (resource, options, refresh = false) { + const sendRequest = function () { + options.credentials = 'include'; + if (options.headers) { + options.headers['Authorization'] = 'Bearer ' + access_token; + } + return fetch(resource, options); + }; + if (!refresh) return sendRequest(); + else { + return new Promise(function (res, rej) { + sendRequest().then(function (x) { + res(x) + }).catch(function () { + w.POP2.getToken(function () { + sendRequest().then(function (x) { + res(x); + }).catch(function (x) { + rej(x); + }); + }) + }); + }); + } } }; -})(this); +})(window);