Merge pull request #1 from 1f349/alpha

Support refreshing tokens
This commit is contained in:
Melon 2023-12-13 02:11:38 +00:00 committed by GitHub
commit 67be57b6b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 339 additions and 182 deletions

1
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/google/subcommands v1.2.0 github.com/google/subcommands v1.2.0
github.com/google/uuid v1.4.0 github.com/google/uuid v1.4.0
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
github.com/rs/cors v1.10.1
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
) )

2
go.sum
View File

@ -46,6 +46,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo=
github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@ -3,18 +3,13 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"encoding/json"
"fmt" "fmt"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/server/pages" "github.com/1f349/lavender/server/pages"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"net/http" "net/http"
"net/mail"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@ -138,93 +133,8 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
http.Error(rw, "Failed to exchange code", http.StatusInternalServerError) http.Error(rw, "Failed to exchange code", http.StatusInternalServerError)
return return
} }
v2, err := testOa2UserInfo(v.sso, req.Context(), exchange)
if err != nil {
fmt.Println("Failed to get userinfo:", err)
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
return
}
defer v2.Body.Close()
if v2.StatusCode != http.StatusOK {
http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError)
return
}
var v3 map[string]any
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
fmt.Println("Failed to decode userinfo:", err)
http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError)
return
}
sub, ok := v3["sub"].(string)
if !ok {
http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError)
return
}
aud, ok := v3["aud"].(string)
if !ok {
http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError)
return
}
var needsMailFlag, needsDomains bool
ps := claims.NewPermStorage()
for _, i := range v.target.Permissions {
if strings.HasPrefix(i, "dynamic:") {
switch i {
case "dynamic:mail-inbox":
needsMailFlag = true
case "dynamic:domain-owns":
needsDomains = true
}
} else {
ps.Set(i)
}
}
if needsMailFlag {
if verified, ok := v3["email_verified"].(bool); ok && verified {
if mailAddress, ok := v3["email"].(string); ok {
address, err := mail.ParseAddress(mailAddress)
if err != nil {
http.Error(rw, "Invalid email in userinfo", http.StatusInternalServerError)
return
}
n := strings.IndexByte(address.Address, '@')
if n != -1 {
if address.Address[n+1:] == v.sso.Config.Namespace {
ps.Set("mail:inbox=" + address.Address)
}
}
}
}
}
if needsDomains {
a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace)
for _, i := range a {
ps.Set("domain:owns=" + i)
}
}
nsSub := sub + "@" + v.sso.Config.Namespace
ati := uuidNewStringAti()
accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{
Perms: ps,
})
if err != nil {
http.Error(rw, "Error generating access token", http.StatusInternalServerError)
return
}
refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati})
if err != nil {
http.Error(rw, "Error generating refresh token", http.StatusInternalServerError)
return
}
h.finishTokenGenerateFlow(rw, req, v, exchange, func(accessToken, refreshToken string, v3 map[string]any) {
pages.RenderPageTemplate(rw, "flow-callback", map[string]any{ pages.RenderPageTemplate(rw, "flow-callback", map[string]any{
"ServiceName": h.conf.Load().ServiceName, "ServiceName": h.conf.Load().ServiceName,
"TargetOrigin": v.target.Url.String(), "TargetOrigin": v.target.Url.String(),
@ -232,4 +142,5 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
"AccessToken": accessToken, "AccessToken": accessToken,
"RefreshToken": refreshToken, "RefreshToken": refreshToken,
}) })
})
} }

View File

@ -356,9 +356,9 @@ func TestFlowCallback(t *testing.T) {
<script> <script>
let loginData = { let loginData = {
target:"%s", target:"%s",
userinfo:{"aud":"%s","sub":"test-user","test-field":"ok"},
tokens: ` tokens: `
const p2 = `, const p2 = `,
userinfo:{"aud":"%s","sub":"test-user","test-field":"ok"},
}; };
window.addEventListener("load", function () { window.addEventListener("load", function () {
window.opener.postMessage(loginData, loginData.target); window.opener.postMessage(loginData, loginData.target);
@ -376,7 +376,8 @@ func TestFlowCallback(t *testing.T) {
</body> </body>
</html> </html>
` `
var p1v = fmt.Sprintf(p1, clientAppDomain, clientAppDomain) var p1v = fmt.Sprintf(p1, clientAppDomain)
var p2v = fmt.Sprintf(p2, clientAppDomain)
a := make([]byte, len(p1v)) a := make([]byte, len(p1v))
n, err := rec.Body.Read(a) n, err := rec.Body.Read(a)
@ -394,7 +395,7 @@ func TestFlowCallback(t *testing.T) {
findByte(rec.Body, ',') findByte(rec.Body, ',')
findByte(rec.Body, '}') findByte(rec.Body, '}')
assert.Equal(t, p2, rec.Body.String()) assert.Equal(t, p2v, rec.Body.String())
} }
func findByte(buf *bytes.Buffer, v byte) { func findByte(buf *bytes.Buffer, v byte) {

View File

@ -5,11 +5,11 @@
<script> <script>
let loginData = { let loginData = {
target:{{.TargetOrigin}}, target:{{.TargetOrigin}},
userinfo:{{.TargetMessage}},
tokens: { tokens: {
access:{{.AccessToken}}, access:{{.AccessToken}},
refresh:{{.RefreshToken}}, refresh:{{.RefreshToken}},
}, },
userinfo:{{.TargetMessage}},
}; };
window.addEventListener("load", function () { window.addEventListener("load", function () {
window.opener.postMessage(loginData, loginData.target); window.opener.postMessage(loginData, loginData.target);

183
server/refresh.go Normal file
View File

@ -0,0 +1,183 @@
package server
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/julienschmidt/httprouter"
"golang.org/x/oauth2"
"net/http"
"net/mail"
"strings"
"time"
)
func (h *HttpServer) refreshHandler(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
ref := strings.TrimSuffix(req.Referer(), "/")
allowedClient, ok := (*h.services.Load())[ref]
if !ok {
http.Error(rw, "Invalid origin", http.StatusBadRequest)
return
}
loginNameCookie, err := req.Cookie("lavender-login-name")
if err != nil {
http.Error(rw, "Failed to read cookie", http.StatusBadRequest)
return
}
loginService := h.manager.Load().FindServiceFromLogin(loginNameCookie.Value)
cookie, err := req.Cookie("sso-exchange")
if err != nil {
http.Error(rw, "Failed to read cookie", http.StatusBadRequest)
return
}
rawEncrypt, err := base64.RawURLEncoding.DecodeString(cookie.Value)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusBadRequest)
return
}
rawTokens, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, h.signer.PrivateKey(), rawEncrypt, []byte("sso-exchange"))
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusBadRequest)
return
}
var exchange oauth2.Token
err = json.Unmarshal(rawTokens, &exchange)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusBadRequest)
return
}
h.finishTokenGenerateFlow(rw, req, flowStateData{
sso: loginService,
target: allowedClient,
}, &exchange, func(accessToken string, refreshToken string, v3 map[string]any) {
tokens := map[string]any{
"target": allowedClient.Url.String(),
"userinfo": v3,
"tokens": map[string]any{
"access": accessToken,
"refresh": refreshToken,
},
}
_ = json.NewEncoder(rw).Encode(tokens)
})
}
func (h *HttpServer) finishTokenGenerateFlow(rw http.ResponseWriter, req *http.Request, v flowStateData, exchange *oauth2.Token, response func(accessToken string, refreshToken string, v3 map[string]any)) {
// fetch user info
v2, err := testOa2UserInfo(v.sso, req.Context(), exchange)
if err != nil {
fmt.Println("Failed to get userinfo:", err)
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
return
}
defer v2.Body.Close()
if v2.StatusCode != http.StatusOK {
http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError)
return
}
// encrypt exchange tokens for cookie storage
marshal, err := json.Marshal(exchange)
if err != nil {
fmt.Println("Failed to marshal exchange tokens", err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
oaepBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, h.signer.PublicKey(), marshal, []byte("sso-exchange"))
if err != nil {
fmt.Println("Failed to encrypt exchange tokens", err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
http.SetCookie(rw, &http.Cookie{
Name: "sso-exchange",
Value: base64.RawURLEncoding.EncodeToString(oaepBytes),
Path: "/",
Expires: time.Now().AddDate(0, 3, 0),
Secure: true,
SameSite: http.SameSiteLaxMode,
})
var v3 map[string]any
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
fmt.Println("Failed to decode userinfo:", err)
http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError)
return
}
sub, ok := v3["sub"].(string)
if !ok {
http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError)
return
}
aud, ok := v3["aud"].(string)
if !ok {
http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError)
return
}
var needsMailFlag, needsDomains bool
ps := claims.NewPermStorage()
for _, i := range v.target.Permissions {
if strings.HasPrefix(i, "dynamic:") {
switch i {
case "dynamic:mail-inbox":
needsMailFlag = true
case "dynamic:domain-owns":
needsDomains = true
}
} else {
ps.Set(i)
}
}
if needsMailFlag {
if verified, ok := v3["email_verified"].(bool); ok && verified {
if mailAddress, ok := v3["email"].(string); ok {
address, err := mail.ParseAddress(mailAddress)
if err != nil {
http.Error(rw, "Invalid email in userinfo", http.StatusInternalServerError)
return
}
n := strings.IndexByte(address.Address, '@')
if n != -1 {
if address.Address[n+1:] == v.sso.Config.Namespace {
ps.Set("mail:inbox=" + address.Address)
}
}
}
}
}
if needsDomains {
a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace)
for _, i := range a {
ps.Set("domain:owns=" + i)
}
}
nsSub := sub + "@" + v.sso.Config.Namespace
ati := uuidNewStringAti()
accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{
Perms: ps,
})
if err != nil {
http.Error(rw, "Error generating access token", http.StatusInternalServerError)
return
}
refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati})
if err != nil {
http.Error(rw, "Error generating refresh token", http.StatusInternalServerError)
return
}
response(accessToken, refreshToken, v3)
}

View File

@ -6,6 +6,7 @@ import (
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/rs/cors"
"log" "log"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
@ -66,6 +67,25 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *HttpServer {
r.GET("/popup", hs.flowPopup) r.GET("/popup", hs.flowPopup)
r.POST("/popup", hs.flowPopupPost) r.POST("/popup", hs.flowPopupPost)
r.GET("/callback", hs.flowCallback) r.GET("/callback", hs.flowCallback)
var corsAccessControl = cors.New(cors.Options{
AllowOriginFunc: func(origin string) bool {
load := hs.services.Load()
_, ok := (*load)[origin]
return ok
},
AllowedMethods: []string{http.MethodPost, http.MethodOptions},
AllowedHeaders: []string{"Content-Type"},
AllowCredentials: true,
})
r.POST("/refresh", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
corsAccessControl.ServeHTTP(rw, req, func(writer http.ResponseWriter, request *http.Request) {
hs.refreshHandler(rw, req, params)
})
})
r.OPTIONS("/refresh", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
corsAccessControl.ServeHTTP(rw, req, func(_ http.ResponseWriter, _ *http.Request) {})
})
return hs return hs
} }

View File

@ -4,16 +4,23 @@
<title>Test Client</title> <title>Test Client</title>
<script> <script>
let currentLoginPopup = null; let currentLoginPopup = null;
let currentTokens = null;
const ssoService = "http://localhost:9090";
window.addEventListener("message", function (event) { function updateTokenInfo(data) {
if (event.origin !== "http:\/\/localhost:9090") return; currentTokens = data.tokens;
if (isObject(event.data)) { data.tokens = {
document.getElementById("someTextArea").textContent = JSON.stringify(event.data, null, 2); access: "*****",
refresh: "*****",
}
document.getElementById("someTextArea").textContent = JSON.stringify(data, null, 2);
let perms = document.getElementById("somePerms"); let perms = document.getElementById("somePerms");
while (perms.childNodes.length > 0) { while (perms.childNodes.length > 0) {
perms.childNodes.item(0).remove(); perms.childNodes.item(0).remove();
} }
let jwt = parseJwt(event.data.tokens.access); document.getElementById("tokenValues").textContent = JSON.stringify(currentTokens, null, 2);
let jwt = parseJwt(currentTokens.access);
if (jwt.per != null) { if (jwt.per != null) {
jwt.per.forEach(function (x) { jwt.per.forEach(function (x) {
let a = document.createElement("li"); let a = document.createElement("li");
@ -21,6 +28,12 @@
perms.appendChild(a); perms.appendChild(a);
}); });
} }
}
window.addEventListener("message", function (event) {
if (event.origin !== ssoService) return;
if (isObject(event.data)) {
updateTokenInfo(event.data);
if (currentLoginPopup) currentLoginPopup.close(); if (currentLoginPopup) currentLoginPopup.close();
return; return;
@ -71,7 +84,22 @@
function doThisThing() { function doThisThing() {
if (currentLoginPopup) currentLoginPopup.close(); if (currentLoginPopup) currentLoginPopup.close();
currentLoginPopup = popupCenterScreen('http://localhost:9090/popup?origin=' + encodeURIComponent("http://localhost:2020"), 'Login with Lavender', 500, 500, false); currentLoginPopup = popupCenterScreen(ssoService + '/popup?origin=' + encodeURIComponent(location.origin), 'Login with Lavender', 500, 500, false);
}
async function refreshAllTokens() {
let req = await fetch(ssoService + '/refresh', {
method: 'POST',
mode: 'cors',
cache: 'no-cache',
credentials: 'include',
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({"token": currentTokens.refresh}),
});
let reqJson = await req.json();
updateTokenInfo(reqJson);
} }
</script> </script>
<style> <style>
@ -79,6 +107,11 @@
width: 400px; width: 400px;
height: 400px; height: 400px;
} }
#tokenValues {
width: 400px;
height: 400px;
}
</style> </style>
</head> </head>
<body> <body>
@ -88,11 +121,17 @@
<main> <main>
<div> <div>
<button onclick="doThisThing();">Login</button> <button onclick="doThisThing();">Login</button>
<button onclick="refreshAllTokens();">Refresh</button>
</div> </div>
<div style="display:flex; gap: 2em;"> <div style="display:flex; gap: 2em;">
<div>
<div> <div>
<label for="someTextArea"></label><textarea id="someTextArea"></textarea> <label for="someTextArea"></label><textarea id="someTextArea"></textarea>
</div> </div>
<div>
<label for="tokenValues"></label><textarea id="tokenValues"></textarea>
</div>
</div>
<div> <div>
<p>Permissions:</p> <p>Permissions:</p>
<ul id="somePerms"></ul> <ul id="somePerms"></ul>