diff --git a/cmd/lavender/conf.go b/cmd/lavender/conf.go index 3bcc34f..9637149 100644 --- a/cmd/lavender/conf.go +++ b/cmd/lavender/conf.go @@ -8,7 +8,6 @@ import ( type startUpConfig struct { Listen string `json:"listen"` BaseUrl string `json:"base_url"` - PrivateKey string `json:"private_key"` Issuer string `json:"issuer"` SsoServices []loginServiceManager.SsoConfig `json:"sso_services"` AllowedClients []utils.JsonUrl `json:"allowed_clients"` diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index 0f67a18..30ca916 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -74,10 +74,10 @@ func normalLoad(startUp startUpConfig, wd string) { manager, err := issuer.NewManager(startUp.SsoServices) if err != nil { - log.Fatal("[Lavender] Failed to create SSO service manager") + log.Fatal("[Lavender] Failed to create SSO service manager: ", err) } - srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, manager, mSign) + srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, startUp.AllowedClients, manager, mSign) log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr) go utils.RunBackgroundHttp("HTTP", srv) diff --git a/go.mod b/go.mod index 3e75443..3e61370 100644 --- a/go.mod +++ b/go.mod @@ -10,15 +10,18 @@ require ( github.com/google/subcommands v1.2.0 github.com/google/uuid v1.3.1 github.com/julienschmidt/httprouter v1.3.0 + github.com/stretchr/testify v1.8.4 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d ) require ( github.com/MrMelon54/rescheduler v0.0.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/protobuf v1.4.2 // indirect github.com/kr/text v0.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/net v0.9.0 // indirect google.golang.org/appengine v1.6.6 // indirect google.golang.org/protobuf v1.23.0 // indirect diff --git a/issuer/manager.go b/issuer/manager.go index aa2cf8e..88c62b5 100644 --- a/issuer/manager.go +++ b/issuer/manager.go @@ -1,5 +1,13 @@ package issuer +import ( + "fmt" + "regexp" + "strings" +) + +var isValidNamespace = regexp.MustCompile("^[0-9a-z.]+$") + type Manager struct { m map[string]*WellKnownOIDC } @@ -7,23 +15,36 @@ type Manager struct { func NewManager(services []SsoConfig) (*Manager, error) { l := &Manager{m: make(map[string]*WellKnownOIDC)} for _, i := range services { + if !isValidNamespace.MatchString(i.Namespace) { + return nil, fmt.Errorf("invalid namespace: %s", i.Namespace) + } + conf, err := i.FetchConfig() if err != nil { return nil, err } - // save by issuer - l.m[conf.Issuer] = conf + // save by namespace + l.m[i.Namespace] = conf } return l, nil } -func (l *Manager) CheckIssuer(issuer string) bool { - _, ok := l.m[issuer] +func (l *Manager) CheckNamespace(namespace string) bool { + _, ok := l.m[namespace] return ok } func (l *Manager) FindServiceFromLogin(login string) *WellKnownOIDC { - - return l.m[namespace] + // @ should have at least one byte before it + n := strings.IndexByte(login, '@') + if n < 1 { + return nil + } + // there should not be a second @ + n2 := strings.IndexByte(login[n+1:], '@') + if n2 != -1 { + return nil + } + return l.m[login[n+1:]] } diff --git a/issuer/manager_test.go b/issuer/manager_test.go new file mode 100644 index 0000000..bf829d0 --- /dev/null +++ b/issuer/manager_test.go @@ -0,0 +1,53 @@ +package issuer + +import ( + "github.com/1f349/lavender/utils" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/url" + "strings" + "testing" +) + +var testAddrUrl = func() utils.JsonUrl { + a, err := url.Parse("https://example.com") + if err != nil { + panic(err) + } + return utils.JsonUrl{URL: a} +}() + +func testBody() io.ReadCloser { + return io.NopCloser(strings.NewReader("{}")) +} + +func TestManager_CheckIssuer(t *testing.T) { + httpGet = func(url string) (resp *http.Response, err error) { + return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil + } + manager, err := NewManager([]SsoConfig{ + { + Addr: testAddrUrl, + Namespace: "example.com", + }, + }) + assert.NoError(t, err) + assert.True(t, manager.CheckNamespace("example.com")) + assert.False(t, manager.CheckNamespace("missing.example.com")) +} + +func TestManager_FindServiceFromLogin(t *testing.T) { + httpGet = func(url string) (resp *http.Response, err error) { + return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil + } + manager, err := NewManager([]SsoConfig{ + { + Addr: testAddrUrl, + Namespace: "example.com", + }, + }) + assert.NoError(t, err) + assert.Equal(t, manager.FindServiceFromLogin("jane@example.com"), manager.m["example.com"]) + assert.Nil(t, manager.FindServiceFromLogin("jane@missing.example.com")) +} diff --git a/issuer/sso.go b/issuer/sso.go index 3621226..2533784 100644 --- a/issuer/sso.go +++ b/issuer/sso.go @@ -12,6 +12,8 @@ import ( "strings" ) +var httpGet = http.Get + // SsoConfig is the base URL for an OAUTH/OPENID/SSO login service // The path `/.well-known/openid-configuration` should be available type SsoConfig struct { @@ -33,7 +35,7 @@ func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) { u += ".well-known/openid-configuration" // fetch metadata - get, err := http.Get(u) + get, err := httpGet(u) if err != nil { return nil, err } @@ -41,19 +43,33 @@ func (s SsoConfig) FetchConfig() (*WellKnownOIDC, error) { var c WellKnownOIDC err = json.NewDecoder(get.Body).Decode(&c) - return &c, err + if err != nil { + return nil, err + } + c.OAuth2Config = oauth2.Config{ + ClientID: c.Config.Client.ID, + ClientSecret: c.Config.Client.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: c.AuthorizationEndpoint, + TokenURL: c.TokenEndpoint, + AuthStyle: oauth2.AuthStyleInHeader, + }, + Scopes: c.Config.Client.Scopes, + } + return &c, nil } type WellKnownOIDC struct { - Config SsoConfig `json:"-"` - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - UserInfoEndpoint string `json:"userinfo_endpoint"` - ResponseTypesSupported []string `json:"response_types_supported"` - ScopesSupported []string `json:"scopes_supported"` - ClaimsSupported []string `json:"claims_supported"` - GrantTypesSupported []string `json:"grant_types_supported"` + Config SsoConfig `json:"-"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"userinfo_endpoint"` + ResponseTypesSupported []string `json:"response_types_supported"` + ScopesSupported []string `json:"scopes_supported"` + ClaimsSupported []string `json:"claims_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + OAuth2Config oauth2.Config `json:"-"` } func (o WellKnownOIDC) Validate() error { @@ -90,19 +106,6 @@ func (o WellKnownOIDC) Validate() error { return nil } -func (o WellKnownOIDC) Oauth2Config() oauth2.Config { - return oauth2.Config{ - ClientID: o.Config.Client.ID, - ClientSecret: o.Config.Client.Secret, - Endpoint: oauth2.Endpoint{ - AuthURL: o.AuthorizationEndpoint, - TokenURL: o.TokenEndpoint, - AuthStyle: oauth2.AuthStyleInHeader, - }, - Scopes: o.Config.Client.Scopes, - } -} - func (o WellKnownOIDC) ValidReturnUrl(u *url.URL) bool { - o.Config.Addr + return o.Config.Addr.Scheme == u.Scheme && o.Config.Addr.Host == u.Host } diff --git a/server/flow-callback.go.html b/server/flow-callback.go.html new file mode 100644 index 0000000..7c72a9d --- /dev/null +++ b/server/flow-callback.go.html @@ -0,0 +1,21 @@ + + + + {{.ServiceName}} + + + +
+

{{.ServiceName}}

+
+
Loading...
+ + diff --git a/server/flow.go b/server/flow.go index 00160fc..ca7302e 100644 --- a/server/flow.go +++ b/server/flow.go @@ -2,13 +2,13 @@ package server import ( _ "embed" + "encoding/json" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "html/template" "log" "net/http" - "net/url" - "regexp" + "strings" "time" ) @@ -17,7 +17,9 @@ var ( flowPopupHtml string flowPopupTemplate *template.Template - isValidState = regexp.MustCompile("^[a-z.]+%[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$") + //go:embed flow-callback.go.html + flowCallbackHtml string + flowCallbackTemplate *template.Template ) func init() { @@ -26,12 +28,17 @@ func init() { log.Fatal("flow.go: Failed to parse flow popup HTML:", err) } flowPopupTemplate = pageParse + pageParse, err = template.New("pages").Parse(flowCallbackHtml) + if err != nil { + log.Fatal("flow.go: Failed to parse flow callback HTML:", err) + } + flowCallbackTemplate = pageParse } func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { err := flowPopupTemplate.Execute(rw, map[string]any{ "ServiceName": flowPopupTemplate, - "Return": req.URL.Query().Get("return"), + "Origin": req.URL.Query().Get("origin"), }) if err != nil { log.Printf("Failed to render page: %s\n", err) @@ -45,13 +52,9 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ return } - returnUrl, err := url.Parse(req.PostFormValue("return")) - if err != nil { - http.Error(rw, "Invalid return URL", http.StatusBadRequest) - return - } - if !login.ValidReturnUrl(returnUrl) { - http.Error(rw, "Invalid return URL for this application", http.StatusBadRequest) + targetOrigin := req.PostFormValue("origin") + if _, found := h.services[targetOrigin]; !found { + http.Error(rw, "Invalid target origin", http.StatusBadRequest) return } @@ -59,11 +62,11 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ state := login.Config.Namespace + "%" + uuid.NewString() h.flowState.Set(state, flowStateData{ login, - returnUrl, + targetOrigin, }, time.Now().Add(15*time.Minute)) // generate oauth2 config and redirect to authorize URL - oa2conf := login.Oauth2Config() + oa2conf := login.OAuth2Config oa2conf.RedirectURL = h.baseUrl + "/callback" nextUrl := oa2conf.AuthCodeURL(state) http.Redirect(rw, req, nextUrl, http.StatusFound) @@ -78,11 +81,8 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h q := req.URL.Query() state := q.Get("state") - if !isValidState.MatchString(state) { - http.Error(rw, "Invalid state", http.StatusBadRequest) - return - } - if !h.manager.CheckIssuer(state) { + n := strings.IndexByte(state, '%') + if !h.manager.CheckNamespace(state[:n]) { http.Error(rw, "Invalid state", http.StatusBadRequest) return } @@ -92,5 +92,30 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h return } - // TODO: process flow callback + exchange, err := v.sso.OAuth2Config.Exchange(req.Context(), q.Get("code")) + if err != nil { + http.Error(rw, "Failed to exchange code", http.StatusInternalServerError) + return + } + client := v.sso.OAuth2Config.Client(req.Context(), exchange) + v2, err := client.Get(v.sso.UserInfoEndpoint) + if err != nil { + http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) + return + } + defer v2.Body.Close() + if v2.StatusCode != http.StatusOK { + http.Error(rw, "Failed to get userinfo", http.StatusInternalServerError) + return + } + var v3 any + if json.NewDecoder(v2.Body).Decode(&v3) != nil { + http.Error(rw, "Failed to decode userinfo JSON", http.StatusInternalServerError) + return + } + + _ = flowCallbackTemplate.Execute(rw, map[string]any{ + "TargetOrigin": v.targetOrigin, + "TargetMessage": v3, + }) } diff --git a/server/server.go b/server/server.go index edd582d..fc6b16b 100644 --- a/server/server.go +++ b/server/server.go @@ -4,10 +4,10 @@ import ( "fmt" "github.com/1f349/cache" "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/utils" "github.com/MrMelon54/mjwt" "github.com/julienschmidt/httprouter" "net/http" - "net/url" "time" ) @@ -17,14 +17,15 @@ type HttpServer struct { manager *issuer.Manager signer mjwt.Signer flowState *cache.Cache[string, flowStateData] + services map[string]struct{} } type flowStateData struct { - sso *issuer.WellKnownOIDC - returnUrl *url.URL + sso *issuer.WellKnownOIDC + targetOrigin string } -func NewHttpServer(listen, baseUrl string, manager *issuer.Manager, signer mjwt.Signer) *http.Server { +func NewHttpServer(listen, baseUrl string, clients []utils.JsonUrl, manager *issuer.Manager, signer mjwt.Signer) *http.Server { r := httprouter.New() // remove last slash from baseUrl @@ -35,11 +36,17 @@ func NewHttpServer(listen, baseUrl string, manager *issuer.Manager, signer mjwt. } } + services := make(map[string]struct{}) + for _, i := range clients { + services[i.Host] = struct{}{} + } + hs := &HttpServer{ - r: r, - baseUrl: baseUrl, - manager: manager, - signer: signer, + r: r, + baseUrl: baseUrl, + manager: manager, + signer: signer, + services: services, } r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {