From b15aa72ec082f878a19d32bd56faacf7f3e83a1b Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 9 Oct 2023 16:29:10 +0100 Subject: [PATCH] Write flow tests --- issuer/manager.go | 4 +- server/flow.go | 25 ++- server/flow_test.go | 369 ++++++++++++++++++++++++++++++++++++++------ 3 files changed, 344 insertions(+), 54 deletions(-) diff --git a/issuer/manager.go b/issuer/manager.go index fe63fe6..c9f7893 100644 --- a/issuer/manager.go +++ b/issuer/manager.go @@ -30,13 +30,13 @@ func NewManager(services []SsoConfig) (*Manager, error) { return l, nil } -func NewManagerForTests(services []WellKnownOIDC) *Manager { +func NewManagerForTests(services []*WellKnownOIDC) *Manager { l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))} for _, i := range services { if !isValidNamespace.MatchString(i.Config.Namespace) { panic("Invalid namespace in tests") } - l.m[i.Config.Namespace] = &i + l.m[i.Config.Namespace] = i } return l } diff --git a/server/flow.go b/server/flow.go index 7b73e83..dab877d 100644 --- a/server/flow.go +++ b/server/flow.go @@ -5,6 +5,7 @@ import ( _ "embed" "encoding/json" "fmt" + "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/server/pages" "github.com/MrMelon54/mjwt/auth" "github.com/MrMelon54/mjwt/claims" @@ -21,6 +22,15 @@ var uuidNewStringState = uuid.NewString var uuidNewStringAti = uuid.NewString var uuidNewStringRti = uuid.NewString +var testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) { + return oa2conf.Exchange(ctx, code) +} + +var testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + client := oidc.OAuth2Config.Client(ctx, exchange) + return client.Get(oidc.UserInfoEndpoint) +} + func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { pages.RenderPageTemplate(rw, "flow-popup", map[string]any{ "ServiceName": h.conf.ServiceName, @@ -67,7 +77,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h state := q.Get("state") n := strings.IndexByte(state, ':') if !h.manager.CheckNamespace(state[:n]) { - http.Error(rw, "Invalid state", http.StatusBadRequest) + http.Error(rw, "Invalid state namespace", http.StatusBadRequest) return } v, found := h.flowState.Get(state) @@ -78,14 +88,13 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h oa2conf := v.sso.OAuth2Config oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" - exchange, err := oa2conf.Exchange(context.Background(), q.Get("code")) + exchange, err := testOa2Exchange(oa2conf, context.Background(), q.Get("code")) if err != nil { fmt.Println("Failed exchange:", err) http.Error(rw, "Failed to exchange code", http.StatusInternalServerError) return } - client := v.sso.OAuth2Config.Client(req.Context(), exchange) - v2, err := client.Get(v.sso.UserInfoEndpoint) + v2, err := testOa2UserInfo(v.sso, req.Context(), exchange) if err != nil { fmt.Println("Failed to get userinfo:", err) http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) @@ -93,25 +102,25 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h } defer v2.Body.Close() if v2.StatusCode != http.StatusOK { - http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) + http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError) return } var v3 map[string]any if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil { fmt.Println("Failed to decode userinfo:", err) - http.Error(rw, "Failed to decode userinfo JSON", http.StatusInternalServerError) + http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError) return } sub, ok := v3["sub"].(string) if !ok { - http.Error(rw, "Invalid value in userinfo", http.StatusInternalServerError) + http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError) return } aud, ok := v3["aud"].(string) if !ok { - http.Error(rw, "Invalid value in userinfo", http.StatusInternalServerError) + http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError) return } diff --git a/server/flow_test.go b/server/flow_test.go index f0d3c90..1a47a29 100644 --- a/server/flow_test.go +++ b/server/flow_test.go @@ -1,11 +1,17 @@ package server import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "errors" "fmt" "github.com/1f349/cache" "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/server/pages" "github.com/1f349/lavender/utils" + "github.com/MrMelon54/mjwt" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" @@ -15,17 +21,73 @@ import ( "net/url" "strings" "testing" + "time" + "unicode" ) const lavenderDomain = "http://localhost:0" const clientAppDomain = "http://localhost:1" const loginDomain = "http://localhost:2" +var testSigner mjwt.Signer + +var testOidc = &issuer.WellKnownOIDC{ + Config: issuer.SsoConfig{ + Addr: utils.JsonUrl{}, + Namespace: "example.com", + Client: issuer.SsoConfigClient{ + ID: "test-id", + Secret: "test-secret", + Scopes: []string{"openid"}, + }, + }, + Issuer: "https://example.com", + AuthorizationEndpoint: loginDomain + "/authorize", + TokenEndpoint: loginDomain + "/token", + UserInfoEndpoint: loginDomain + "/userinfo", + ResponseTypesSupported: nil, + ScopesSupported: nil, + ClaimsSupported: nil, + GrantTypesSupported: nil, + OAuth2Config: oauth2.Config{ + ClientID: "test-id", + ClientSecret: "test-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: loginDomain + "/authorize", + TokenURL: loginDomain + "/token", + AuthStyle: oauth2.AuthStyleInHeader, + }, + Scopes: nil, + }, +} + +var testManager = issuer.NewManagerForTests([]*issuer.WellKnownOIDC{testOidc}) +var testHttpServer = HttpServer{ + r: nil, + conf: Conf{ + BaseUrl: lavenderDomain, + ServiceName: "Test Lavender Service", + }, + manager: testManager, + flowState: cache.New[string, flowStateData](), + services: map[string]struct{}{ + clientAppDomain: {}, + }, +} + func init() { err := pages.LoadPages("") if err != nil { panic(err) } + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + testSigner = mjwt.NewMJwtSigner("https://example.com", key) + testHttpServer.signer = testSigner } func TestFlowPopup(t *testing.T) { @@ -59,47 +121,6 @@ func TestFlowPopup(t *testing.T) { } func TestFlowPopupPost(t *testing.T) { - manager := issuer.NewManagerForTests([]issuer.WellKnownOIDC{ - { - Config: issuer.SsoConfig{ - Addr: utils.JsonUrl{}, - Namespace: "example.com", - Client: issuer.SsoConfigClient{ - ID: "test-id", - Secret: "test-secret", - Scopes: []string{"openid"}, - }, - }, - Issuer: "https://example.com", - AuthorizationEndpoint: loginDomain + "/authorize", - TokenEndpoint: loginDomain + "/token", - UserInfoEndpoint: loginDomain + "/userinfo", - ResponseTypesSupported: nil, - ScopesSupported: nil, - ClaimsSupported: nil, - GrantTypesSupported: nil, - OAuth2Config: oauth2.Config{ - ClientID: "test-id", - ClientSecret: "test-secret", - Endpoint: oauth2.Endpoint{ - AuthURL: loginDomain + "/authorize", - TokenURL: loginDomain + "/token", - AuthStyle: oauth2.AuthStyleInHeader, - }, - Scopes: nil, - }, - }, - }) - h := HttpServer{ - r: nil, - conf: Conf{BaseUrl: lavenderDomain}, - manager: manager, - flowState: cache.New[string, flowStateData](), - services: map[string]struct{}{ - clientAppDomain: {}, - }, - } - // test no login service error rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{ @@ -107,7 +128,7 @@ func TestFlowPopupPost(t *testing.T) { "origin": []string{clientAppDomain}, }.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - h.flowPopupPost(rec, req, httprouter.Params{}) + testHttpServer.flowPopupPost(rec, req, httprouter.Params{}) assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, "No login service defined for this username\n", rec.Body.String()) @@ -118,7 +139,7 @@ func TestFlowPopupPost(t *testing.T) { "origin": []string{"http://localhost:1010"}, }.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - h.flowPopupPost(rec, req, httprouter.Params{}) + testHttpServer.flowPopupPost(rec, req, httprouter.Params{}) assert.Equal(t, http.StatusBadRequest, rec.Code) assert.Equal(t, "Invalid target origin\n", rec.Body.String()) @@ -131,7 +152,7 @@ func TestFlowPopupPost(t *testing.T) { "origin": []string{clientAppDomain}, }.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - h.flowPopupPost(rec, req, httprouter.Params{}) + testHttpServer.flowPopupPost(rec, req, httprouter.Params{}) assert.Equal(t, http.StatusFound, rec.Code) assert.Equal(t, "", rec.Body.String()) assert.Equal(t, loginDomain+"/authorize?"+url.Values{ @@ -142,3 +163,263 @@ func TestFlowPopupPost(t *testing.T) { "state": []string{"example.com:" + nextState}, }.Encode(), rec.Header().Get("Location")) } + +func TestFlowCallback(t *testing.T) { + expiryTime := time.Now().Add(15 * time.Minute) + nextState := uuid.NewString() + testHttpServer.flowState.Set("example.com:"+nextState, flowStateData{ + sso: testOidc, + targetOrigin: clientAppDomain, + }, expiryTime) + + testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) { + return nil, errors.New("no exchange should be made") + } + testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + return nil, errors.New("no userinfo should be fetched") + } + + // test parse form error + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/callback?%+"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "Error parsing form\n", rec.Body.String()) + + // test invalid namespace + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"missing.example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "Invalid state namespace\n", rec.Body.String()) + + // test invalid state + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:invalid"}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "Invalid state\n", rec.Body.String()) + + // test failed exchange + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "Failed to exchange code\n", rec.Body.String()) + + testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "abcd1234", + TokenType: "", + RefreshToken: "efgh5678", + Expiry: expiryTime, + }, nil + } + + // test failed userinfo + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "Failed to get userinfo\n", rec.Body.String()) + + testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusInternalServerError) + return rec.Result(), nil + } + + // test failed userinfo status code + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "Failed to get userinfo: unexpected status code\n", rec.Body.String()) + + testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + _, _ = rec.Body.WriteString("{") + return rec.Result(), nil + } + + // test failed userinfo decode + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "Failed to decode userinfo\n", rec.Body.String()) + + testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + _, _ = rec.Body.WriteString("{\"sub\":1}") + return rec.Result(), nil + } + + // test invalid subject in userinfo + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "Invalid subject in userinfo\n", rec.Body.String()) + + testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + _, _ = rec.Body.WriteString("{\"sub\":\"1\",\"aud\":1}") + return rec.Result(), nil + } + + // test invalid audience in userinfo + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "Invalid audience in userinfo\n", rec.Body.String()) + + testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + _, _ = rec.Body.WriteString(fmt.Sprintf(`{ + "sub": "test-user", + "aud": "%s", + "test-field": "ok" +} +`, clientAppDomain)) + return rec.Result(), nil + } + + // test successful request + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{ + "state": []string{"example.com:" + nextState}, + "origin": []string{clientAppDomain}, + }.Encode(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + testHttpServer.flowCallback(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusOK, rec.Code) + const p1 = ` + + + Test Lavender Service + + + +
+

Test Lavender Service

+
+
Loading...
+ + +` + var p1v = fmt.Sprintf(p1, clientAppDomain, clientAppDomain) + + a := make([]byte, len(p1v)) + n, err := rec.Body.Read(a) + assert.NoError(t, err) + assert.Equal(t, len(p1v), n) + assert.Equal(t, p1v, string(a)) + + var accessToken, refreshToken string + findByte(rec.Body, '{') + findString(rec.Body, "access:") + readQuotedString(rec.Body, &accessToken) + findByte(rec.Body, ',') + findString(rec.Body, "refresh:") + readQuotedString(rec.Body, &refreshToken) + findByte(rec.Body, ',') + findByte(rec.Body, '}') + + assert.Equal(t, p2, rec.Body.String()) +} + +func findByte(buf *bytes.Buffer, v byte) { + for { + readByte, err := buf.ReadByte() + if err != nil { + panic(err) + } + if readByte == v { + break + } + if !unicode.IsSpace(rune(readByte)) { + panic(fmt.Sprint("Found non space rune: ", readByte)) + } + } +} + +func findString(buf *bytes.Buffer, v string) { + if len(v) == 0 { + panic("Cannot find empty string") + } + findByte(buf, v[0]) + if len(v) > 1 { + a2 := make([]byte, len(v)-1) + n, err := buf.Read(a2) + if err != nil { + panic(err) + } + if n != len(a2) { + panic("Probably found end of buffer") + } + if bytes.Compare([]byte(v[1:]), a2) != 0 { + panic("Failed to find string in buffer") + } + } +} + +func readQuotedString(buf *bytes.Buffer, p *string) { + findByte(buf, '"') + b, err := buf.ReadBytes('"') + if err != nil { + panic(err) + } + *p = string(b) +}