diff --git a/issuer/manager.go b/issuer/manager.go index 88c62b5..fe63fe6 100644 --- a/issuer/manager.go +++ b/issuer/manager.go @@ -30,6 +30,17 @@ func NewManager(services []SsoConfig) (*Manager, error) { return l, nil } +func NewManagerForTests(services []WellKnownOIDC) *Manager { + l := &Manager{m: make(map[string]*WellKnownOIDC, len(services))} + for _, i := range services { + if !isValidNamespace.MatchString(i.Config.Namespace) { + panic("Invalid namespace in tests") + } + l.m[i.Config.Namespace] = &i + } + return l +} + func (l *Manager) CheckNamespace(namespace string) bool { _, ok := l.m[namespace] return ok diff --git a/issuer/sso.go b/issuer/sso.go index f208e4f..5d7d521 100644 --- a/issuer/sso.go +++ b/issuer/sso.go @@ -17,13 +17,15 @@ var httpGet = http.Get // SsoConfig is the base URL for an OAUTH/OPENID/SSO login service // The path `/.well-known/openid-configuration` should be available type SsoConfig struct { - Addr utils.JsonUrl `json:"addr"` // https://login.example.com - Namespace string `json:"namespace"` // example.com - Client struct { - ID string `json:"id"` - Secret string `json:"secret"` - Scopes []string `json:"scopes"` - } `json:"client"` + Addr utils.JsonUrl `json:"addr"` // https://login.example.com + Namespace string `json:"namespace"` // example.com + Client SsoConfigClient `json:"client"` +} + +type SsoConfigClient struct { + ID string `json:"id"` + Secret string `json:"secret"` + Scopes []string `json:"scopes"` } func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) { diff --git a/server/flow.go b/server/flow.go index 0aace24..7b73e83 100644 --- a/server/flow.go +++ b/server/flow.go @@ -12,20 +12,20 @@ import ( "github.com/google/uuid" "github.com/julienschmidt/httprouter" "golang.org/x/oauth2" - "log" "net/http" "strings" "time" ) +var uuidNewStringState = uuid.NewString +var uuidNewStringAti = uuid.NewString +var uuidNewStringRti = uuid.NewString + func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { - err := pages.FlowTemplates.Execute(rw, map[string]any{ + pages.RenderPageTemplate(rw, "flow-popup", map[string]any{ "ServiceName": h.conf.ServiceName, "Origin": req.URL.Query().Get("origin"), }) - if err != nil { - log.Printf("Failed to render page: %s\n", err) - } } func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { @@ -43,7 +43,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ } // save state for use later - state := login.Config.Namespace + ":" + uuid.NewString() + state := login.Config.Namespace + ":" + uuidNewStringState() h.flowState.Set(state, flowStateData{ login, targetOrigin, @@ -117,7 +117,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h ps := claims.NewPermStorage() nsSub := sub + "@" + v.sso.Config.Namespace - ati := uuid.NewString() + ati := uuidNewStringAti() accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{ Perms: ps, }) @@ -126,13 +126,13 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h return } - refreshToken, err := h.signer.GenerateJwt(nsSub, uuid.NewString(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati}) + refreshToken, err := h.signer.GenerateJwt(nsSub, uuidNewStringRti(), jwt.ClaimStrings{aud}, 15*time.Minute, auth.RefreshTokenClaims{AccessTokenId: ati}) if err != nil { http.Error(rw, "Error generating refresh token", http.StatusInternalServerError) return } - _ = pages.FlowTemplates.Execute(rw, map[string]any{ + pages.RenderPageTemplate(rw, "flow-callback", map[string]any{ "ServiceName": h.conf.ServiceName, "TargetOrigin": v.targetOrigin, "TargetMessage": v3, diff --git a/server/flow_test.go b/server/flow_test.go index abb4e43..f0d3c90 100644 --- a/server/flow_test.go +++ b/server/flow_test.go @@ -1 +1,144 @@ package server + +import ( + "fmt" + "github.com/1f349/cache" + "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/server/pages" + "github.com/1f349/lavender/utils" + "github.com/google/uuid" + "github.com/julienschmidt/httprouter" + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +const lavenderDomain = "http://localhost:0" +const clientAppDomain = "http://localhost:1" +const loginDomain = "http://localhost:2" + +func init() { + err := pages.LoadPages("") + if err != nil { + panic(err) + } +} + +func TestFlowPopup(t *testing.T) { + h := HttpServer{conf: Conf{ServiceName: "Test Service Name"}} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/popup?"+url.Values{"origin": []string{clientAppDomain}}.Encode(), nil) + h.flowPopup(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, fmt.Sprintf(` + + + Test Service Name + + +
+

Test Service Name

+
+
+
+ +
+ + +
+ +
+
+ + +`, clientAppDomain), rec.Body.String()) +} + +func TestFlowPopupPost(t *testing.T) { + manager := issuer.NewManagerForTests([]issuer.WellKnownOIDC{ + { + Config: issuer.SsoConfig{ + Addr: utils.JsonUrl{}, + Namespace: "example.com", + Client: issuer.SsoConfigClient{ + ID: "test-id", + Secret: "test-secret", + Scopes: []string{"openid"}, + }, + }, + Issuer: "https://example.com", + AuthorizationEndpoint: loginDomain + "/authorize", + TokenEndpoint: loginDomain + "/token", + UserInfoEndpoint: loginDomain + "/userinfo", + ResponseTypesSupported: nil, + ScopesSupported: nil, + ClaimsSupported: nil, + GrantTypesSupported: nil, + OAuth2Config: oauth2.Config{ + ClientID: "test-id", + ClientSecret: "test-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: loginDomain + "/authorize", + TokenURL: loginDomain + "/token", + AuthStyle: oauth2.AuthStyleInHeader, + }, + Scopes: nil, + }, + }, + }) + h := HttpServer{ + r: nil, + conf: Conf{BaseUrl: lavenderDomain}, + manager: manager, + flowState: cache.New[string, flowStateData](), + services: map[string]struct{}{ + clientAppDomain: {}, + }, + } + + // test no login service error + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{ + "loginname": []string{"test@missing.example.com"}, + "origin": []string{clientAppDomain}, + }.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.flowPopupPost(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "No login service defined for this username\n", rec.Body.String()) + + // test invalid target origin error + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{ + "loginname": []string{"test@example.com"}, + "origin": []string{"http://localhost:1010"}, + }.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.flowPopupPost(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "Invalid target origin\n", rec.Body.String()) + + // test successful request + nextState := uuid.NewString() + uuidNewStringState = func() string { return nextState } + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/popup", strings.NewReader(url.Values{ + "loginname": []string{"test@example.com"}, + "origin": []string{clientAppDomain}, + }.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h.flowPopupPost(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusFound, rec.Code) + assert.Equal(t, "", rec.Body.String()) + assert.Equal(t, loginDomain+"/authorize?"+url.Values{ + "client_id": []string{"test-id"}, + "login_name": []string{"test@example.com"}, + "redirect_uri": []string{lavenderDomain + "/callback"}, + "response_type": []string{"code"}, + "state": []string{"example.com:" + nextState}, + }.Encode(), rec.Header().Get("Location")) +} diff --git a/server/pages/pages.go b/server/pages/pages.go index 765c18c..a0c75b0 100644 --- a/server/pages/pages.go +++ b/server/pages/pages.go @@ -5,24 +5,41 @@ import ( _ "embed" "github.com/1f349/overlapfs" "html/template" + "io" + "io/fs" + "log" "os" "path/filepath" + "sync" ) var ( //go:embed *.go.html flowPages embed.FS - FlowTemplates *template.Template + flowTemplates *template.Template + loadOnce sync.Once ) -func LoadPages(wd string) error { - wwwDir := filepath.Join(wd, "www") - err := os.Mkdir(wwwDir, os.ModePerm) - if err != nil { - return nil - } - wdFs := os.DirFS(wwwDir) - o := overlapfs.OverlapFS{A: flowPages, B: wdFs} - FlowTemplates, err = template.ParseFS(o, "*.go.html") +func LoadPages(wd string) (err error) { + loadOnce.Do(func() { + var o fs.FS = flowPages + if wd != "" { + wwwDir := filepath.Join(wd, "www") + err = os.Mkdir(wwwDir, os.ModePerm) + if err != nil { + return + } + wdFs := os.DirFS(wwwDir) + o = overlapfs.OverlapFS{A: flowPages, B: wdFs} + } + flowTemplates, err = template.ParseFS(o, "*.go.html") + }) return err } + +func RenderPageTemplate(wr io.Writer, name string, data any) { + err := flowTemplates.ExecuteTemplate(wr, name+".go.html", data) + if err != nil { + log.Printf("Failed to render page: %s: %s\n", name, err) + } +}