mirror of
https://github.com/1f349/lavender.git
synced 2025-01-11 17:26:28 +00:00
Write flow tests
This commit is contained in:
parent
1280c30c5e
commit
b15aa72ec0
@ -30,13 +30,13 @@ func NewManager(services []SsoConfig) (*Manager, error) {
|
|||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManagerForTests(services []WellKnownOIDC) *Manager {
|
func NewManagerForTests(services []*WellKnownOIDC) *Manager {
|
||||||
l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))}
|
l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))}
|
||||||
for _, i := range services {
|
for _, i := range services {
|
||||||
if !isValidNamespace.MatchString(i.Config.Namespace) {
|
if !isValidNamespace.MatchString(i.Config.Namespace) {
|
||||||
panic("Invalid namespace in tests")
|
panic("Invalid namespace in tests")
|
||||||
}
|
}
|
||||||
l.m[i.Config.Namespace] = &i
|
l.m[i.Config.Namespace] = i
|
||||||
}
|
}
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/server/pages"
|
"github.com/1f349/lavender/server/pages"
|
||||||
"github.com/MrMelon54/mjwt/auth"
|
"github.com/MrMelon54/mjwt/auth"
|
||||||
"github.com/MrMelon54/mjwt/claims"
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
@ -21,6 +22,15 @@ var uuidNewStringState = uuid.NewString
|
|||||||
var uuidNewStringAti = uuid.NewString
|
var uuidNewStringAti = uuid.NewString
|
||||||
var uuidNewStringRti = uuid.NewString
|
var uuidNewStringRti = uuid.NewString
|
||||||
|
|
||||||
|
var testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) {
|
||||||
|
return oa2conf.Exchange(ctx, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
client := oidc.OAuth2Config.Client(ctx, exchange)
|
||||||
|
return client.Get(oidc.UserInfoEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||||
pages.RenderPageTemplate(rw, "flow-popup", map[string]any{
|
pages.RenderPageTemplate(rw, "flow-popup", map[string]any{
|
||||||
"ServiceName": h.conf.ServiceName,
|
"ServiceName": h.conf.ServiceName,
|
||||||
@ -67,7 +77,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
|
|||||||
state := q.Get("state")
|
state := q.Get("state")
|
||||||
n := strings.IndexByte(state, ':')
|
n := strings.IndexByte(state, ':')
|
||||||
if !h.manager.CheckNamespace(state[:n]) {
|
if !h.manager.CheckNamespace(state[:n]) {
|
||||||
http.Error(rw, "Invalid state", http.StatusBadRequest)
|
http.Error(rw, "Invalid state namespace", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
v, found := h.flowState.Get(state)
|
v, found := h.flowState.Get(state)
|
||||||
@ -78,14 +88,13 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
|
|||||||
|
|
||||||
oa2conf := v.sso.OAuth2Config
|
oa2conf := v.sso.OAuth2Config
|
||||||
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
|
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback"
|
||||||
exchange, err := oa2conf.Exchange(context.Background(), q.Get("code"))
|
exchange, err := testOa2Exchange(oa2conf, context.Background(), q.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed exchange:", err)
|
fmt.Println("Failed exchange:", err)
|
||||||
http.Error(rw, "Failed to exchange code", http.StatusInternalServerError)
|
http.Error(rw, "Failed to exchange code", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
client := v.sso.OAuth2Config.Client(req.Context(), exchange)
|
v2, err := testOa2UserInfo(v.sso, req.Context(), exchange)
|
||||||
v2, err := client.Get(v.sso.UserInfoEndpoint)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed to get userinfo:", err)
|
fmt.Println("Failed to get userinfo:", err)
|
||||||
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
|
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
|
||||||
@ -93,25 +102,25 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
|
|||||||
}
|
}
|
||||||
defer v2.Body.Close()
|
defer v2.Body.Close()
|
||||||
if v2.StatusCode != http.StatusOK {
|
if v2.StatusCode != http.StatusOK {
|
||||||
http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError)
|
http.Error(rw, "Failed to get userinfo: unexpected status code", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var v3 map[string]any
|
var v3 map[string]any
|
||||||
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
|
if err = json.NewDecoder(v2.Body).Decode(&v3); err != nil {
|
||||||
fmt.Println("Failed to decode userinfo:", err)
|
fmt.Println("Failed to decode userinfo:", err)
|
||||||
http.Error(rw, "Failed to decode userinfo JSON", http.StatusInternalServerError)
|
http.Error(rw, "Failed to decode userinfo", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sub, ok := v3["sub"].(string)
|
sub, ok := v3["sub"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(rw, "Invalid value in userinfo", http.StatusInternalServerError)
|
http.Error(rw, "Invalid subject in userinfo", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aud, ok := v3["aud"].(string)
|
aud, ok := v3["aud"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(rw, "Invalid value in userinfo", http.StatusInternalServerError)
|
http.Error(rw, "Invalid audience in userinfo", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/cache"
|
"github.com/1f349/cache"
|
||||||
"github.com/1f349/lavender/issuer"
|
"github.com/1f349/lavender/issuer"
|
||||||
"github.com/1f349/lavender/server/pages"
|
"github.com/1f349/lavender/server/pages"
|
||||||
"github.com/1f349/lavender/utils"
|
"github.com/1f349/lavender/utils"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -15,17 +21,73 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
const lavenderDomain = "http://localhost:0"
|
const lavenderDomain = "http://localhost:0"
|
||||||
const clientAppDomain = "http://localhost:1"
|
const clientAppDomain = "http://localhost:1"
|
||||||
const loginDomain = "http://localhost:2"
|
const loginDomain = "http://localhost:2"
|
||||||
|
|
||||||
|
var testSigner mjwt.Signer
|
||||||
|
|
||||||
|
var testOidc = &issuer.WellKnownOIDC{
|
||||||
|
Config: issuer.SsoConfig{
|
||||||
|
Addr: utils.JsonUrl{},
|
||||||
|
Namespace: "example.com",
|
||||||
|
Client: issuer.SsoConfigClient{
|
||||||
|
ID: "test-id",
|
||||||
|
Secret: "test-secret",
|
||||||
|
Scopes: []string{"openid"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Issuer: "https://example.com",
|
||||||
|
AuthorizationEndpoint: loginDomain + "/authorize",
|
||||||
|
TokenEndpoint: loginDomain + "/token",
|
||||||
|
UserInfoEndpoint: loginDomain + "/userinfo",
|
||||||
|
ResponseTypesSupported: nil,
|
||||||
|
ScopesSupported: nil,
|
||||||
|
ClaimsSupported: nil,
|
||||||
|
GrantTypesSupported: nil,
|
||||||
|
OAuth2Config: oauth2.Config{
|
||||||
|
ClientID: "test-id",
|
||||||
|
ClientSecret: "test-secret",
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: loginDomain + "/authorize",
|
||||||
|
TokenURL: loginDomain + "/token",
|
||||||
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
|
},
|
||||||
|
Scopes: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var testManager = issuer.NewManagerForTests([]*issuer.WellKnownOIDC{testOidc})
|
||||||
|
var testHttpServer = HttpServer{
|
||||||
|
r: nil,
|
||||||
|
conf: Conf{
|
||||||
|
BaseUrl: lavenderDomain,
|
||||||
|
ServiceName: "Test Lavender Service",
|
||||||
|
},
|
||||||
|
manager: testManager,
|
||||||
|
flowState: cache.New[string, flowStateData](),
|
||||||
|
services: map[string]struct{}{
|
||||||
|
clientAppDomain: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
err := pages.LoadPages("")
|
err := pages.LoadPages("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testSigner = mjwt.NewMJwtSigner("https://example.com", key)
|
||||||
|
testHttpServer.signer = testSigner
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFlowPopup(t *testing.T) {
|
func TestFlowPopup(t *testing.T) {
|
||||||
@ -59,47 +121,6 @@ func TestFlowPopup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFlowPopupPost(t *testing.T) {
|
func TestFlowPopupPost(t *testing.T) {
|
||||||
manager := issuer.NewManagerForTests([]issuer.WellKnownOIDC{
|
|
||||||
{
|
|
||||||
Config: issuer.SsoConfig{
|
|
||||||
Addr: utils.JsonUrl{},
|
|
||||||
Namespace: "example.com",
|
|
||||||
Client: issuer.SsoConfigClient{
|
|
||||||
ID: "test-id",
|
|
||||||
Secret: "test-secret",
|
|
||||||
Scopes: []string{"openid"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Issuer: "https://example.com",
|
|
||||||
AuthorizationEndpoint: loginDomain + "/authorize",
|
|
||||||
TokenEndpoint: loginDomain + "/token",
|
|
||||||
UserInfoEndpoint: loginDomain + "/userinfo",
|
|
||||||
ResponseTypesSupported: nil,
|
|
||||||
ScopesSupported: nil,
|
|
||||||
ClaimsSupported: nil,
|
|
||||||
GrantTypesSupported: nil,
|
|
||||||
OAuth2Config: oauth2.Config{
|
|
||||||
ClientID: "test-id",
|
|
||||||
ClientSecret: "test-secret",
|
|
||||||
Endpoint: oauth2.Endpoint{
|
|
||||||
AuthURL: loginDomain + "/authorize",
|
|
||||||
TokenURL: loginDomain + "/token",
|
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
|
||||||
},
|
|
||||||
Scopes: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
h := HttpServer{
|
|
||||||
r: nil,
|
|
||||||
conf: Conf{BaseUrl: lavenderDomain},
|
|
||||||
manager: manager,
|
|
||||||
flowState: cache.New[string, flowStateData](),
|
|
||||||
services: map[string]struct{}{
|
|
||||||
clientAppDomain: {},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// test no login service error
|
// test no login service error
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{
|
req := httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{
|
||||||
@ -107,7 +128,7 @@ func TestFlowPopupPost(t *testing.T) {
|
|||||||
"origin": []string{clientAppDomain},
|
"origin": []string{clientAppDomain},
|
||||||
}.Encode()))
|
}.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
h.flowPopupPost(rec, req, httprouter.Params{})
|
testHttpServer.flowPopupPost(rec, req, httprouter.Params{})
|
||||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
assert.Equal(t, "No login service defined for this username\n", rec.Body.String())
|
assert.Equal(t, "No login service defined for this username\n", rec.Body.String())
|
||||||
|
|
||||||
@ -118,7 +139,7 @@ func TestFlowPopupPost(t *testing.T) {
|
|||||||
"origin": []string{"http://localhost:1010"},
|
"origin": []string{"http://localhost:1010"},
|
||||||
}.Encode()))
|
}.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
h.flowPopupPost(rec, req, httprouter.Params{})
|
testHttpServer.flowPopupPost(rec, req, httprouter.Params{})
|
||||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
assert.Equal(t, "Invalid target origin\n", rec.Body.String())
|
assert.Equal(t, "Invalid target origin\n", rec.Body.String())
|
||||||
|
|
||||||
@ -131,7 +152,7 @@ func TestFlowPopupPost(t *testing.T) {
|
|||||||
"origin": []string{clientAppDomain},
|
"origin": []string{clientAppDomain},
|
||||||
}.Encode()))
|
}.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
h.flowPopupPost(rec, req, httprouter.Params{})
|
testHttpServer.flowPopupPost(rec, req, httprouter.Params{})
|
||||||
assert.Equal(t, http.StatusFound, rec.Code)
|
assert.Equal(t, http.StatusFound, rec.Code)
|
||||||
assert.Equal(t, "", rec.Body.String())
|
assert.Equal(t, "", rec.Body.String())
|
||||||
assert.Equal(t, loginDomain+"/authorize?"+url.Values{
|
assert.Equal(t, loginDomain+"/authorize?"+url.Values{
|
||||||
@ -142,3 +163,263 @@ func TestFlowPopupPost(t *testing.T) {
|
|||||||
"state": []string{"example.com:" + nextState},
|
"state": []string{"example.com:" + nextState},
|
||||||
}.Encode(), rec.Header().Get("Location"))
|
}.Encode(), rec.Header().Get("Location"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFlowCallback(t *testing.T) {
|
||||||
|
expiryTime := time.Now().Add(15 * time.Minute)
|
||||||
|
nextState := uuid.NewString()
|
||||||
|
testHttpServer.flowState.Set("example.com:"+nextState, flowStateData{
|
||||||
|
sso: testOidc,
|
||||||
|
targetOrigin: clientAppDomain,
|
||||||
|
}, expiryTime)
|
||||||
|
|
||||||
|
testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) {
|
||||||
|
return nil, errors.New("no exchange should be made")
|
||||||
|
}
|
||||||
|
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
return nil, errors.New("no userinfo should be fetched")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test parse form error
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/callback?%+"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
assert.Equal(t, "Error parsing form\n", rec.Body.String())
|
||||||
|
|
||||||
|
// test invalid namespace
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"missing.example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
assert.Equal(t, "Invalid state namespace\n", rec.Body.String())
|
||||||
|
|
||||||
|
// test invalid state
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:invalid"},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
assert.Equal(t, "Invalid state\n", rec.Body.String())
|
||||||
|
|
||||||
|
// test failed exchange
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Equal(t, "Failed to exchange code\n", rec.Body.String())
|
||||||
|
|
||||||
|
testOa2Exchange = func(oa2conf oauth2.Config, ctx context.Context, code string) (*oauth2.Token, error) {
|
||||||
|
return &oauth2.Token{
|
||||||
|
AccessToken: "abcd1234",
|
||||||
|
TokenType: "",
|
||||||
|
RefreshToken: "efgh5678",
|
||||||
|
Expiry: expiryTime,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// test failed userinfo
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Equal(t, "Failed to get userinfo\n", rec.Body.String())
|
||||||
|
|
||||||
|
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return rec.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// test failed userinfo status code
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Equal(t, "Failed to get userinfo: unexpected status code\n", rec.Body.String())
|
||||||
|
|
||||||
|
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rec.Body.WriteString("{")
|
||||||
|
return rec.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// test failed userinfo decode
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Equal(t, "Failed to decode userinfo\n", rec.Body.String())
|
||||||
|
|
||||||
|
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rec.Body.WriteString("{\"sub\":1}")
|
||||||
|
return rec.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// test invalid subject in userinfo
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Equal(t, "Invalid subject in userinfo\n", rec.Body.String())
|
||||||
|
|
||||||
|
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rec.Body.WriteString("{\"sub\":\"1\",\"aud\":1}")
|
||||||
|
return rec.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// test invalid audience in userinfo
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Equal(t, "Invalid audience in userinfo\n", rec.Body.String())
|
||||||
|
|
||||||
|
testOa2UserInfo = func(oidc *issuer.WellKnownOIDC, ctx context.Context, exchange *oauth2.Token) (*http.Response, error) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rec.Body.WriteString(fmt.Sprintf(`{
|
||||||
|
"sub": "test-user",
|
||||||
|
"aud": "%s",
|
||||||
|
"test-field": "ok"
|
||||||
|
}
|
||||||
|
`, clientAppDomain))
|
||||||
|
return rec.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// test successful request
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/callback?"+url.Values{
|
||||||
|
"state": []string{"example.com:" + nextState},
|
||||||
|
"origin": []string{clientAppDomain},
|
||||||
|
}.Encode(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
testHttpServer.flowCallback(rec, req, httprouter.Params{})
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
const p1 = `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<title>Test Lavender Service</title>
|
||||||
|
<script>
|
||||||
|
let loginData = {
|
||||||
|
target:"%s",
|
||||||
|
userinfo:{"aud":"%s","sub":"test-user","test-field":"ok"},
|
||||||
|
tokens: `
|
||||||
|
const p2 = `,
|
||||||
|
};
|
||||||
|
window.addEventListener("load", function () {
|
||||||
|
window.opener.postMessage(loginData, loginData.target);
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<header>
|
||||||
|
<h1>Test Lavender Service</h1>
|
||||||
|
</header>
|
||||||
|
<main id="mainBody">Loading...</main>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`
|
||||||
|
var p1v = fmt.Sprintf(p1, clientAppDomain, clientAppDomain)
|
||||||
|
|
||||||
|
a := make([]byte, len(p1v))
|
||||||
|
n, err := rec.Body.Read(a)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(p1v), n)
|
||||||
|
assert.Equal(t, p1v, string(a))
|
||||||
|
|
||||||
|
var accessToken, refreshToken string
|
||||||
|
findByte(rec.Body, '{')
|
||||||
|
findString(rec.Body, "access:")
|
||||||
|
readQuotedString(rec.Body, &accessToken)
|
||||||
|
findByte(rec.Body, ',')
|
||||||
|
findString(rec.Body, "refresh:")
|
||||||
|
readQuotedString(rec.Body, &refreshToken)
|
||||||
|
findByte(rec.Body, ',')
|
||||||
|
findByte(rec.Body, '}')
|
||||||
|
|
||||||
|
assert.Equal(t, p2, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func findByte(buf *bytes.Buffer, v byte) {
|
||||||
|
for {
|
||||||
|
readByte, err := buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if readByte == v {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !unicode.IsSpace(rune(readByte)) {
|
||||||
|
panic(fmt.Sprint("Found non space rune: ", readByte))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findString(buf *bytes.Buffer, v string) {
|
||||||
|
if len(v) == 0 {
|
||||||
|
panic("Cannot find empty string")
|
||||||
|
}
|
||||||
|
findByte(buf, v[0])
|
||||||
|
if len(v) > 1 {
|
||||||
|
a2 := make([]byte, len(v)-1)
|
||||||
|
n, err := buf.Read(a2)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if n != len(a2) {
|
||||||
|
panic("Probably found end of buffer")
|
||||||
|
}
|
||||||
|
if bytes.Compare([]byte(v[1:]), a2) != 0 {
|
||||||
|
panic("Failed to find string in buffer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readQuotedString(buf *bytes.Buffer, p *string) {
|
||||||
|
findByte(buf, '"')
|
||||||
|
b, err := buf.ReadBytes('"')
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
*p = string(b)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user