4
0
mirror of https://github.com/1f349/lavender.git synced 2025-01-11 17:26:28 +00:00

Write flow tests

This commit is contained in:
Melon 2023-10-09 16:29:10 +01:00
parent 1280c30c5e
commit b15aa72ec0
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
3 changed files with 344 additions and 54 deletions

View File

@ -30,13 +30,13 @@ func NewManager(services []SsoConfig) (*Manager, error) {
return l, nil return l, nil
} }
func NewManagerForTests(services []WellKnownOIDC) *Manager { func NewManagerForTests(services []*WellKnownOIDC) *Manager {
l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))} l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))}
for _, i := range services { for _, i := range services {
if !isValidNamespace.MatchString(i.Config.Namespace) { if !isValidNamespace.MatchString(i.Config.Namespace) {
panic("Invalid namespace in tests") panic("Invalid namespace in tests")
} }
l.m[i.Config.Namespace] = &i l.m[i.Config.Namespace] = i
} }
return l return l
} }

View File

@ -5,6 +5,7 @@ import (
_ "embed" _ "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/server/pages" "github.com/1f349/lavender/server/pages"
"github.com/MrMelon54/mjwt/auth" "github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims" "github.com/MrMelon54/mjwt/claims"
@ -21,6 +22,15 @@ var uuidNewStringState = uuid.NewString
var uuidNewStringAti = uuid.NewString var uuidNewStringAti = uuid.NewString
var uuidNewStringRti = uuid.NewString var uuidNewStringRti = uuid.NewString
var testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) {
return oa2conf.Exchange(ctx, code)
}
var testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
client := oidc.OAuth2Config.Client(ctx, exchange)
return client.Get(oidc.UserInfoEndpoint)
}
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
pages.RenderPageTemplate(rw, "flow-popup", map[string]any{ pages.RenderPageTemplate(rw, "flow-popup", map[string]any{
"ServiceName": h.conf.ServiceName, "ServiceName": h.conf.ServiceName,
@ -67,7 +77,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
state := q.Get("state") state := q.Get("state")
n := strings.IndexByte(state, ':') n := strings.IndexByte(state, ':')
if !h.manager.CheckNamespace(state[:n]) { if !h.manager.CheckNamespace(state[:n]) {
http.Error(rw, "Invalid state", http.StatusBadRequest) http.Error(rw, "Invalid state namespace", http.StatusBadRequest)
return return
} }
v, found := h.flowState.Get(state) v, found := h.flowState.Get(state)
@ -78,14 +88,13 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
oa2conf := v.sso.OAuth2Config oa2conf := v.sso.OAuth2Config
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
exchange, err := oa2conf.Exchange(context.Background(), q.Get("code")) exchange, err := testOa2Exchange(oa2conf, context.Background(), q.Get("code"))
if err != nil { if err != nil {
fmt.Println("Failed exchange:", err) fmt.Println("Failed exchange:", err)
http.Error(rw, "Failed to exchange code", http.StatusInternalServerError) http.Error(rw, "Failed to exchange code", http.StatusInternalServerError)
return return
} }
client := v.sso.OAuth2Config.Client(req.Context(), exchange) v2, err := testOa2UserInfo(v.sso, req.Context(), exchange)
v2, err := client.Get(v.sso.UserInfoEndpoint)
if err != nil { if err != nil {
fmt.Println("Failed to get userinfo:", err) fmt.Println("Failed to get userinfo:", err)
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
@ -93,25 +102,25 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
} }
defer v2.Body.Close() defer v2.Body.Close()
if v2.StatusCode != http.StatusOK { if v2.StatusCode != http.StatusOK {
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError)
return return
} }
var v3 map[string]any var v3 map[string]any
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil { if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
fmt.Println("Failed to decode userinfo:", err) fmt.Println("Failed to decode userinfo:", err)
http.Error(rw, "Failed to decode userinfo JSON", http.StatusInternalServerError) http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError)
return return
} }
sub, ok := v3["sub"].(string) sub, ok := v3["sub"].(string)
if !ok { if !ok {
http.Error(rw, "Invalid value in userinfo", http.StatusInternalServerError) http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError)
return return
} }
aud, ok := v3["aud"].(string) aud, ok := v3["aud"].(string)
if !ok { if !ok {
http.Error(rw, "Invalid value in userinfo", http.StatusInternalServerError) http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError)
return return
} }

View File

@ -1,11 +1,17 @@
package server package server
import ( import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt" "fmt"
"github.com/1f349/cache" "github.com/1f349/cache"
"github.com/1f349/lavender/issuer" "github.com/1f349/lavender/issuer"
"github.com/1f349/lavender/server/pages" "github.com/1f349/lavender/server/pages"
"github.com/1f349/lavender/utils" "github.com/1f349/lavender/utils"
"github.com/MrMelon54/mjwt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -15,17 +21,73 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"unicode"
) )
const lavenderDomain = "http://localhost:0" const lavenderDomain = "http://localhost:0"
const clientAppDomain = "http://localhost:1" const clientAppDomain = "http://localhost:1"
const loginDomain = "http://localhost:2" const loginDomain = "http://localhost:2"
var testSigner mjwt.Signer
var testOidc = &issuer.WellKnownOIDC{
Config: issuer.SsoConfig{
Addr: utils.JsonUrl{},
Namespace: "example.com",
Client: issuer.SsoConfigClient{
ID: "test-id",
Secret: "test-secret",
Scopes: []string{"openid"},
},
},
Issuer: "https://example.com",
AuthorizationEndpoint: loginDomain + "/authorize",
TokenEndpoint: loginDomain + "/token",
UserInfoEndpoint: loginDomain + "/userinfo",
ResponseTypesSupported: nil,
ScopesSupported: nil,
ClaimsSupported: nil,
GrantTypesSupported: nil,
OAuth2Config: oauth2.Config{
ClientID: "test-id",
ClientSecret: "test-secret",
Endpoint: oauth2.Endpoint{
AuthURL: loginDomain + "/authorize",
TokenURL: loginDomain + "/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: nil,
},
}
var testManager = issuer.NewManagerForTests([]*issuer.WellKnownOIDC{testOidc})
var testHttpServer = HttpServer{
r: nil,
conf: Conf{
BaseUrl: lavenderDomain,
ServiceName: "Test Lavender Service",
},
manager: testManager,
flowState: cache.New[string, flowStateData](),
services: map[string]struct{}{
clientAppDomain: {},
},
}
func init() { func init() {
err := pages.LoadPages("") err := pages.LoadPages("")
if err != nil { if err != nil {
panic(err) panic(err)
} }
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
testSigner = mjwt.NewMJwtSigner("https://example.com", key)
testHttpServer.signer = testSigner
} }
func TestFlowPopup(t *testing.T) { func TestFlowPopup(t *testing.T) {
@ -59,47 +121,6 @@ func TestFlowPopup(t *testing.T) {
} }
func TestFlowPopupPost(t *testing.T) { func TestFlowPopupPost(t *testing.T) {
manager := issuer.NewManagerForTests([]issuer.WellKnownOIDC{
{
Config: issuer.SsoConfig{
Addr: utils.JsonUrl{},
Namespace: "example.com",
Client: issuer.SsoConfigClient{
ID: "test-id",
Secret: "test-secret",
Scopes: []string{"openid"},
},
},
Issuer: "https://example.com",
AuthorizationEndpoint: loginDomain + "/authorize",
TokenEndpoint: loginDomain + "/token",
UserInfoEndpoint: loginDomain + "/userinfo",
ResponseTypesSupported: nil,
ScopesSupported: nil,
ClaimsSupported: nil,
GrantTypesSupported: nil,
OAuth2Config: oauth2.Config{
ClientID: "test-id",
ClientSecret: "test-secret",
Endpoint: oauth2.Endpoint{
AuthURL: loginDomain + "/authorize",
TokenURL: loginDomain + "/token",
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: nil,
},
},
})
h := HttpServer{
r: nil,
conf: Conf{BaseUrl: lavenderDomain},
manager: manager,
flowState: cache.New[string, flowStateData](),
services: map[string]struct{}{
clientAppDomain: {},
},
}
// test no login service error // test no login service error
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{ req := httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{
@ -107,7 +128,7 @@ func TestFlowPopupPost(t *testing.T) {
"origin": []string{clientAppDomain}, "origin": []string{clientAppDomain},
}.Encode())) }.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.flowPopupPost(rec, req, httprouter.Params{}) testHttpServer.flowPopupPost(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "No login service defined for this username\n", rec.Body.String()) assert.Equal(t, "No login service defined for this username\n", rec.Body.String())
@ -118,7 +139,7 @@ func TestFlowPopupPost(t *testing.T) {
"origin": []string{"http://localhost:1010"}, "origin": []string{"http://localhost:1010"},
}.Encode())) }.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.flowPopupPost(rec, req, httprouter.Params{}) testHttpServer.flowPopupPost(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "Invalid target origin\n", rec.Body.String()) assert.Equal(t, "Invalid target origin\n", rec.Body.String())
@ -131,7 +152,7 @@ func TestFlowPopupPost(t *testing.T) {
"origin": []string{clientAppDomain}, "origin": []string{clientAppDomain},
}.Encode())) }.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h.flowPopupPost(rec, req, httprouter.Params{}) testHttpServer.flowPopupPost(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusFound, rec.Code) assert.Equal(t, http.StatusFound, rec.Code)
assert.Equal(t, "", rec.Body.String()) assert.Equal(t, "", rec.Body.String())
assert.Equal(t, loginDomain+"/authorize?"+url.Values{ assert.Equal(t, loginDomain+"/authorize?"+url.Values{
@ -142,3 +163,263 @@ func TestFlowPopupPost(t *testing.T) {
"state": []string{"example.com:" + nextState}, "state": []string{"example.com:" + nextState},
}.Encode(), rec.Header().Get("Location")) }.Encode(), rec.Header().Get("Location"))
} }
func TestFlowCallback(t *testing.T) {
expiryTime := time.Now().Add(15 * time.Minute)
nextState := uuid.NewString()
testHttpServer.flowState.Set("example.com:"+nextState, flowStateData{
sso: testOidc,
targetOrigin: clientAppDomain,
}, expiryTime)
testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) {
return nil, errors.New("no exchange should be made")
}
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
return nil, errors.New("no userinfo should be fetched")
}
// test parse form error
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/callback?%+"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "Error parsing form\n", rec.Body.String())
// test invalid namespace
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"missing.example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "Invalid state namespace\n", rec.Body.String())
// test invalid state
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:invalid"},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, "Invalid state\n", rec.Body.String())
// test failed exchange
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "Failed to exchange code\n", rec.Body.String())
testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "abcd1234",
TokenType: "",
RefreshToken: "efgh5678",
Expiry: expiryTime,
}, nil
}
// test failed userinfo
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "Failed to get userinfo\n", rec.Body.String())
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusInternalServerError)
return rec.Result(), nil
}
// test failed userinfo status code
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "Failed to get userinfo: unexpected status code\n", rec.Body.String())
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
_, _ = rec.Body.WriteString("{")
return rec.Result(), nil
}
// test failed userinfo decode
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "Failed to decode userinfo\n", rec.Body.String())
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
_, _ = rec.Body.WriteString("{\"sub\":1}")
return rec.Result(), nil
}
// test invalid subject in userinfo
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "Invalid subject in userinfo\n", rec.Body.String())
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
_, _ = rec.Body.WriteString("{\"sub\":\"1\",\"aud\":1}")
return rec.Result(), nil
}
// test invalid audience in userinfo
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "Invalid audience in userinfo\n", rec.Body.String())
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
_, _ = rec.Body.WriteString(fmt.Sprintf(`{
"sub": "test-user",
"aud": "%s",
"test-field": "ok"
}
`, clientAppDomain))
return rec.Result(), nil
}
// test successful request
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
"state": []string{"example.com:" + nextState},
"origin": []string{clientAppDomain},
}.Encode(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
testHttpServer.flowCallback(rec, req, httprouter.Params{})
assert.Equal(t, http.StatusOK, rec.Code)
const p1 = `<!DOCTYPE html>
<html lang="en">
<head>
<title>Test Lavender Service</title>
<script>
let loginData = {
target:"%s",
userinfo:{"aud":"%s","sub":"test-user","test-field":"ok"},
tokens: `
const p2 = `,
};
window.addEventListener("load", function () {
window.opener.postMessage(loginData, loginData.target);
});
</script>
</head>
<body>
<header>
<h1>Test Lavender Service</h1>
</header>
<main id="mainBody">Loading...</main>
</body>
</html>
`
var p1v = fmt.Sprintf(p1, clientAppDomain, clientAppDomain)
a := make([]byte, len(p1v))
n, err := rec.Body.Read(a)
assert.NoError(t, err)
assert.Equal(t, len(p1v), n)
assert.Equal(t, p1v, string(a))
var accessToken, refreshToken string
findByte(rec.Body, '{')
findString(rec.Body, "access:")
readQuotedString(rec.Body, &accessToken)
findByte(rec.Body, ',')
findString(rec.Body, "refresh:")
readQuotedString(rec.Body, &refreshToken)
findByte(rec.Body, ',')
findByte(rec.Body, '}')
assert.Equal(t, p2, rec.Body.String())
}
func findByte(buf *bytes.Buffer, v byte) {
for {
readByte, err := buf.ReadByte()
if err != nil {
panic(err)
}
if readByte == v {
break
}
if !unicode.IsSpace(rune(readByte)) {
panic(fmt.Sprint("Found non space rune: ", readByte))
}
}
}
func findString(buf *bytes.Buffer, v string) {
if len(v) == 0 {
panic("Cannot find empty string")
}
findByte(buf, v[0])
if len(v) > 1 {
a2 := make([]byte, len(v)-1)
n, err := buf.Read(a2)
if err != nil {
panic(err)
}
if n != len(a2) {
panic("Probably found end of buffer")
}
if bytes.Compare([]byte(v[1:]), a2) != 0 {
panic("Failed to find string in buffer")
}
}
}
func readQuotedString(buf *bytes.Buffer, p *string) {
findByte(buf, '"')
b, err := buf.ReadBytes('"')
if err != nil {
panic(err)
}
*p = string(b)
}