Merge pull request #1 from 1f349/alpha

Support refreshing tokens
This commit is contained in:
Melon 2023-12-13 02:11:38 +00:00 committed by GitHub
commit 67be57b6b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 339 additions and 182 deletions

1
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/google/subcommands v1.2.0
github.com/google/uuid v1.4.0
github.com/julienschmidt/httprouter v1.3.0
github.com/rs/cors v1.10.1
github.com/stretchr/testify v1.8.4
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
)

2
go.sum
View File

@ -46,6 +46,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo=
github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@ -3,18 +3,13 @@ package server
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/server/pages"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/julienschmidt/httprouter"
"golang.org/x/oauth2"
"net/http"
"net/mail"
"net/url"
"strings"
"time"
@ -138,98 +133,14 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
http.Error(rw, "Failed to exchange code", http.StatusInternalServerError)
return
}
v2, err := testOa2UserInfo(v.sso, req.Context(), exchange)
if err != nil {
fmt.Println("Failed to get userinfo:", err)
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
return
}
defer v2.Body.Close()
if v2.StatusCode != http.StatusOK {
http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError)
return
}
var v3 map[string]any
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
fmt.Println("Failed to decode userinfo:", err)
http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError)
return
}
sub, ok := v3["sub"].(string)
if !ok {
http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError)
return
}
aud, ok := v3["aud"].(string)
if !ok {
http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError)
return
}
var needsMailFlag, needsDomains bool
ps := claims.NewPermStorage()
for _, i := range v.target.Permissions {
if strings.HasPrefix(i, "dynamic:") {
switch i {
case "dynamic:mail-inbox":
needsMailFlag = true
case "dynamic:domain-owns":
needsDomains = true
}
} else {
ps.Set(i)
}
}
if needsMailFlag {
if verified, ok := v3["email_verified"].(bool); ok && verified {
if mailAddress, ok := v3["email"].(string); ok {
address, err := mail.ParseAddress(mailAddress)
if err != nil {
http.Error(rw, "Invalid email in userinfo", http.StatusInternalServerError)
return
}
n := strings.IndexByte(address.Address, '@')
if n != -1 {
if address.Address[n+1:] == v.sso.Config.Namespace {
ps.Set("mail:inbox=" + address.Address)
}
}
}
}
}
if needsDomains {
a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace)
for _, i := range a {
ps.Set("domain:owns=" + i)
}
}
nsSub := sub + "@" + v.sso.Config.Namespace
ati := uuidNewStringAti()
accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{
Perms: ps,
})
if err != nil {
http.Error(rw, "Error generating access token", http.StatusInternalServerError)
return
}
refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati})
if err != nil {
http.Error(rw, "Error generating refresh token", http.StatusInternalServerError)
return
}
pages.RenderPageTemplate(rw, "flow-callback", map[string]any{
"ServiceName": h.conf.Load().ServiceName,
"TargetOrigin": v.target.Url.String(),
"TargetMessage": v3,
"AccessToken": accessToken,
"RefreshToken": refreshToken,
h.finishTokenGenerateFlow(rw, req, v, exchange, func(accessToken, refreshToken string, v3 map[string]any) {
pages.RenderPageTemplate(rw, "flow-callback", map[string]any{
"ServiceName": h.conf.Load().ServiceName,
"TargetOrigin": v.target.Url.String(),
"TargetMessage": v3,
"AccessToken": accessToken,
"RefreshToken": refreshToken,
})
})
}

View File

@ -356,9 +356,9 @@ func TestFlowCallback(t *testing.T) {
<script>
let loginData = {
target:"%s",
userinfo:{"aud":"%s","sub":"test-user","test-field":"ok"},
tokens: `
const p2 = `,
userinfo:{"aud":"%s","sub":"test-user","test-field":"ok"},
};
window.addEventListener("load", function () {
window.opener.postMessage(loginData, loginData.target);
@ -376,7 +376,8 @@ func TestFlowCallback(t *testing.T) {
</body>
</html>
`
var p1v = fmt.Sprintf(p1, clientAppDomain, clientAppDomain)
var p1v = fmt.Sprintf(p1, clientAppDomain)
var p2v = fmt.Sprintf(p2, clientAppDomain)
a := make([]byte, len(p1v))
n, err := rec.Body.Read(a)
@ -394,7 +395,7 @@ func TestFlowCallback(t *testing.T) {
findByte(rec.Body, ',')
findByte(rec.Body, '}')
assert.Equal(t, p2, rec.Body.String())
assert.Equal(t, p2v, rec.Body.String())
}
func findByte(buf *bytes.Buffer, v byte) {

View File

@ -5,11 +5,11 @@
<script>
let loginData = {
target:{{.TargetOrigin}},
userinfo:{{.TargetMessage}},
tokens: {
access:{{.AccessToken}},
refresh:{{.RefreshToken}},
},
userinfo:{{.TargetMessage}},
};
window.addEventListener("load", function () {
window.opener.postMessage(loginData, loginData.target);

183
server/refresh.go Normal file
View File

@ -0,0 +1,183 @@
package server
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/julienschmidt/httprouter"
"golang.org/x/oauth2"
"net/http"
"net/mail"
"strings"
"time"
)
func (h *HttpServer) refreshHandler(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
ref := strings.TrimSuffix(req.Referer(), "/")
allowedClient, ok := (*h.services.Load())[ref]
if !ok {
http.Error(rw, "Invalid origin", http.StatusBadRequest)
return
}
loginNameCookie, err := req.Cookie("lavender-login-name")
if err != nil {
http.Error(rw, "Failed to read cookie", http.StatusBadRequest)
return
}
loginService := h.manager.Load().FindServiceFromLogin(loginNameCookie.Value)
cookie, err := req.Cookie("sso-exchange")
if err != nil {
http.Error(rw, "Failed to read cookie", http.StatusBadRequest)
return
}
rawEncrypt, err := base64.RawURLEncoding.DecodeString(cookie.Value)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusBadRequest)
return
}
rawTokens, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, h.signer.PrivateKey(), rawEncrypt, []byte("sso-exchange"))
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusBadRequest)
return
}
var exchange oauth2.Token
err = json.Unmarshal(rawTokens, &exchange)
if err != nil {
http.Error(rw, "Internal Server Error", http.StatusBadRequest)
return
}
h.finishTokenGenerateFlow(rw, req, flowStateData{
sso: loginService,
target: allowedClient,
}, &exchange, func(accessToken string, refreshToken string, v3 map[string]any) {
tokens := map[string]any{
"target": allowedClient.Url.String(),
"userinfo": v3,
"tokens": map[string]any{
"access": accessToken,
"refresh": refreshToken,
},
}
_ = json.NewEncoder(rw).Encode(tokens)
})
}
func (h *HttpServer) finishTokenGenerateFlow(rw http.ResponseWriter, req *http.Request, v flowStateData, exchange *oauth2.Token, response func(accessToken string, refreshToken string, v3 map[string]any)) {
// fetch user info
v2, err := testOa2UserInfo(v.sso, req.Context(), exchange)
if err != nil {
fmt.Println("Failed to get userinfo:", err)
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
return
}
defer v2.Body.Close()
if v2.StatusCode != http.StatusOK {
http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError)
return
}
// encrypt exchange tokens for cookie storage
marshal, err := json.Marshal(exchange)
if err != nil {
fmt.Println("Failed to marshal exchange tokens", err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
oaepBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, h.signer.PublicKey(), marshal, []byte("sso-exchange"))
if err != nil {
fmt.Println("Failed to encrypt exchange tokens", err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
http.SetCookie(rw, &http.Cookie{
Name: "sso-exchange",
Value: base64.RawURLEncoding.EncodeToString(oaepBytes),
Path: "/",
Expires: time.Now().AddDate(0, 3, 0),
Secure: true,
SameSite: http.SameSiteLaxMode,
})
var v3 map[string]any
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
fmt.Println("Failed to decode userinfo:", err)
http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError)
return
}
sub, ok := v3["sub"].(string)
if !ok {
http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError)
return
}
aud, ok := v3["aud"].(string)
if !ok {
http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError)
return
}
var needsMailFlag, needsDomains bool
ps := claims.NewPermStorage()
for _, i := range v.target.Permissions {
if strings.HasPrefix(i, "dynamic:") {
switch i {
case "dynamic:mail-inbox":
needsMailFlag = true
case "dynamic:domain-owns":
needsDomains = true
}
} else {
ps.Set(i)
}
}
if needsMailFlag {
if verified, ok := v3["email_verified"].(bool); ok && verified {
if mailAddress, ok := v3["email"].(string); ok {
address, err := mail.ParseAddress(mailAddress)
if err != nil {
http.Error(rw, "Invalid email in userinfo", http.StatusInternalServerError)
return
}
n := strings.IndexByte(address.Address, '@')
if n != -1 {
if address.Address[n+1:] == v.sso.Config.Namespace {
ps.Set("mail:inbox=" + address.Address)
}
}
}
}
}
if needsDomains {
a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace)
for _, i := range a {
ps.Set("domain:owns=" + i)
}
}
nsSub := sub + "@" + v.sso.Config.Namespace
ati := uuidNewStringAti()
accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{
Perms: ps,
})
if err != nil {
http.Error(rw, "Error generating access token", http.StatusInternalServerError)
return
}
refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati})
if err != nil {
http.Error(rw, "Error generating refresh token", http.StatusInternalServerError)
return
}
response(accessToken, refreshToken, v3)
}

View File

@ -6,6 +6,7 @@ import (
"github.com/1f349/lavender/issuer"
"github.com/1f349/mjwt"
"github.com/julienschmidt/httprouter"
"github.com/rs/cors"
"log"
"net/http"
"sync/atomic"
@ -66,6 +67,25 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *HttpServer {
r.GET("/popup", hs.flowPopup)
r.POST("/popup", hs.flowPopupPost)
r.GET("/callback", hs.flowCallback)
var corsAccessControl = cors.New(cors.Options{
AllowOriginFunc: func(origin string) bool {
load := hs.services.Load()
_, ok := (*load)[origin]
return ok
},
AllowedMethods: []string{http.MethodPost, http.MethodOptions},
AllowedHeaders: []string{"Content-Type"},
AllowCredentials: true,
})
r.POST("/refresh", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
corsAccessControl.ServeHTTP(rw, req, func(writer http.ResponseWriter, request *http.Request) {
hs.refreshHandler(rw, req, params)
})
})
r.OPTIONS("/refresh", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
corsAccessControl.ServeHTTP(rw, req, func(_ http.ResponseWriter, _ *http.Request) {})
})
return hs
}

View File

@ -1,103 +1,142 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>Test Client</title>
<script>
let currentLoginPopup = null;
<title>Test Client</title>
<script>
let currentLoginPopup = null;
let currentTokens = null;
const ssoService = "http://localhost:9090";
window.addEventListener("message", function (event) {
if (event.origin !== "http:\/\/localhost:9090") return;
if (isObject(event.data)) {
document.getElementById("someTextArea").textContent = JSON.stringify(event.data, null, 2);
let perms = document.getElementById("somePerms");
while (perms.childNodes.length > 0) {
perms.childNodes.item(0).remove();
}
let jwt = parseJwt(event.data.tokens.access);
if (jwt.per != null) {
jwt.per.forEach(function (x) {
let a = document.createElement("li");
a.textContent = x;
perms.appendChild(a);
});
}
function updateTokenInfo(data) {
currentTokens = data.tokens;
data.tokens = {
access: "*****",
refresh: "*****",
}
document.getElementById("someTextArea").textContent = JSON.stringify(data, null, 2);
let perms = document.getElementById("somePerms");
while (perms.childNodes.length > 0) {
perms.childNodes.item(0).remove();
}
document.getElementById("tokenValues").textContent = JSON.stringify(currentTokens, null, 2);
if (currentLoginPopup) currentLoginPopup.close();
return;
}
alert("Failed to log user in: the login data was probably corrupted");
});
let jwt = parseJwt(currentTokens.access);
if (jwt.per != null) {
jwt.per.forEach(function (x) {
let a = document.createElement("li");
a.textContent = x;
perms.appendChild(a);
});
}
}
function parseJwt(token) {
const base64Url = token.split('.')[1];
const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/');
const jsonPayload = decodeURIComponent(window.atob(base64).split('').map(function (c) {
return '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2);
}).join(''));
return JSON.parse(jsonPayload);
}
window.addEventListener("message", function (event) {
if (event.origin !== ssoService) return;
if (isObject(event.data)) {
updateTokenInfo(event.data);
function isObject(obj) {
return obj != null && obj.constructor.name === "Object"
}
if (currentLoginPopup) currentLoginPopup.close();
return;
}
alert("Failed to log user in: the login data was probably corrupted");
});
function popupCenterScreen(url, title, w, h, focus) {
const top = (screen.availHeight - h) / 4, left = (screen.availWidth - w) / 2;
const popup = openWindow(url, title, `scrollbars=yes,width=${w},height=${h},top=${top},left=${left}`);
if (focus === true && window.focus) popup.focus();
return popup;
}
function parseJwt(token) {
const base64Url = token.split('.')[1];
const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/');
const jsonPayload = decodeURIComponent(window.atob(base64).split('').map(function (c) {
return '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2);
}).join(''));
return JSON.parse(jsonPayload);
}
function openWindow(url, winnm, options) {
var wTop = firstAvailableValue([window.screen.availTop, window.screenY, window.screenTop, 0]);
var wLeft = firstAvailableValue([window.screen.availLeft, window.screenX, window.screenLeft, 0]);
var top = 0, left = 0;
var result;
if ((result = /top=(\d+)/g.exec(options))) top = parseInt(result[1]);
if ((result = /left=(\d+)/g.exec(options))) left = parseInt(result[1]);
if (options) {
options = options.replace("top=" + top, "top=" + (parseInt(top) + wTop));
options = options.replace("left=" + left, "left=" + (parseInt(left) + wLeft));
w = window.open(url, winnm, options);
} else w = window.open(url, winnm);
return w;
}
function isObject(obj) {
return obj != null && obj.constructor.name === "Object"
}
function firstAvailableValue(arr) {
for (var i = 0; i < arr.length; i++)
if (typeof arr[i] != 'undefined')
return arr[i];
}
function popupCenterScreen(url, title, w, h, focus) {
const top = (screen.availHeight - h) / 4, left = (screen.availWidth - w) / 2;
const popup = openWindow(url, title, `scrollbars=yes,width=${w},height=${h},top=${top},left=${left}`);
if (focus === true && window.focus) popup.focus();
return popup;
}
function doThisThing() {
if (currentLoginPopup) currentLoginPopup.close();
currentLoginPopup = popupCenterScreen('http://localhost:9090/popup?origin=' + encodeURIComponent("http://localhost:2020"), 'Login with Lavender', 500, 500, false);
}
</script>
<style>
#someTextArea {
width: 400px;
height: 400px;
}
</style>
function openWindow(url, winnm, options) {
var wTop = firstAvailableValue([window.screen.availTop, window.screenY, window.screenTop, 0]);
var wLeft = firstAvailableValue([window.screen.availLeft, window.screenX, window.screenLeft, 0]);
var top = 0, left = 0;
var result;
if ((result = /top=(\d+)/g.exec(options))) top = parseInt(result[1]);
if ((result = /left=(\d+)/g.exec(options))) left = parseInt(result[1]);
if (options) {
options = options.replace("top=" + top, "top=" + (parseInt(top) + wTop));
options = options.replace("left=" + left, "left=" + (parseInt(left) + wLeft));
w = window.open(url, winnm, options);
} else w = window.open(url, winnm);
return w;
}
function firstAvailableValue(arr) {
for (var i = 0; i < arr.length; i++)
if (typeof arr[i] != 'undefined')
return arr[i];
}
function doThisThing() {
if (currentLoginPopup) currentLoginPopup.close();
currentLoginPopup = popupCenterScreen(ssoService + '/popup?origin=' + encodeURIComponent(location.origin), 'Login with Lavender', 500, 500, false);
}
async function refreshAllTokens() {
let req = await fetch(ssoService + '/refresh', {
method: 'POST',
mode: 'cors',
cache: 'no-cache',
credentials: 'include',
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({"token": currentTokens.refresh}),
});
let reqJson = await req.json();
updateTokenInfo(reqJson);
}
</script>
<style>
#someTextArea {
width: 400px;
height: 400px;
}
#tokenValues {
width: 400px;
height: 400px;
}
</style>
</head>
<body>
<header>
<h1>Test Client</h1>
<h1>Test Client</h1>
</header>
<main>
<div>
<button onclick="doThisThing();">Login</button>
<button onclick="refreshAllTokens();">Refresh</button>
</div>
<div style="display:flex; gap: 2em;">
<div>
<button onclick="doThisThing();">Login</button>
<div>
<label for="someTextArea"></label><textarea id="someTextArea"></textarea>
</div>
<div>
<label for="tokenValues"></label><textarea id="tokenValues"></textarea>
</div>
</div>
<div style="display:flex; gap: 2em;">
<div>
<label for="someTextArea"></label><textarea id="someTextArea"></textarea>
</div>
<div>
<p>Permissions:</p>
<ul id="somePerms"></ul>
</div>
<div>
<p>Permissions:</p>
<ul id="somePerms"></ul>
</div>
</div>
</main>
</body>
</html>