diff --git a/cmd/lavender/conf.go b/cmd/lavender/conf.go deleted file mode 100644 index c9b465f..0000000 --- a/cmd/lavender/conf.go +++ /dev/null @@ -1,15 +0,0 @@ -package main - -import ( - loginServiceManager "github.com/1f349/lavender/issuer" - "github.com/1f349/lavender/utils" -) - -type startUpConfig struct { - Listen string `json:"listen"` - BaseUrl string `json:"base_url"` - ServiceName string `json:"service_name"` - Issuer string `json:"issuer"` - SsoServices []loginServiceManager.SsoConfig `json:"sso_services"` - AllowedClients []utils.JsonUrl `json:"allowed_clients"` -} diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index 4f94ca7..881d222 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "encoding/json" "flag" - "github.com/1f349/lavender/issuer" "github.com/1f349/lavender/server" "github.com/1f349/violet/utils" exit_reload "github.com/MrMelon54/exit-reload" @@ -50,7 +49,7 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) return subcommands.ExitFailure } - var config startUpConfig + var config server.Conf err = json.NewDecoder(openConf).Decode(&config) if err != nil { log.Println("[Lavender] Error: invalid config file: ", err) @@ -66,18 +65,13 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) return subcommands.ExitSuccess } -func normalLoad(startUp startUpConfig, wd string) { +func normalLoad(startUp server.Conf, wd string) { mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096) if err != nil { log.Fatal("[Lavender] Failed to load or create MJWT signer:", err) } - manager, err := issuer.NewManager(startUp.SsoServices) - if err != nil { - log.Fatal("[Lavender] Failed to create SSO service manager: ", err) - } - - srv := server.NewHttpServer(startUp.Listen, startUp.BaseUrl, startUp.ServiceName, startUp.AllowedClients, manager, mSign) + srv := server.NewHttpServer(startUp, mSign) log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr) go utils.RunBackgroundHttp("HTTP", srv) diff --git a/issuer/manager_test.go b/issuer/manager_test.go index bf829d0..143014e 100644 --- a/issuer/manager_test.go +++ b/issuer/manager_test.go @@ -22,7 +22,7 @@ func testBody() io.ReadCloser { return io.NopCloser(strings.NewReader("{}")) } -func TestManager_CheckIssuer(t *testing.T) { +func TestManager_CheckNamespace(t *testing.T) { httpGet = func(url string) (resp *http.Response, err error) { return &http.Response{StatusCode: http.StatusOK, Body: testBody()}, nil } diff --git a/issuer/sso_test.go b/issuer/sso_test.go new file mode 100644 index 0000000..de84991 --- /dev/null +++ b/issuer/sso_test.go @@ -0,0 +1 @@ +package issuer diff --git a/server/conf.go b/server/conf.go new file mode 100644 index 0000000..bb7ddbb --- /dev/null +++ b/server/conf.go @@ -0,0 +1,15 @@ +package server + +import ( + "github.com/1f349/lavender/issuer" + "github.com/1f349/lavender/utils" +) + +type Conf struct { + Listen string `json:"listen"` + BaseUrl string `json:"base_url"` + ServiceName string `json:"service_name"` + Issuer string `json:"issuer"` + SsoServices []issuer.SsoConfig `json:"sso_services"` + AllowedClients []utils.JsonUrl `json:"allowed_clients"` +} diff --git a/server/flow.go b/server/flow.go index 852100e..69ad3a5 100644 --- a/server/flow.go +++ b/server/flow.go @@ -18,16 +18,6 @@ import ( "time" ) -var ( - //go:embed flow-popup.go.html - flowPopupHtml string - flowPopupTemplate *template.Template - - //go:embed flow-callback.go.html - flowCallbackHtml string - flowCallbackTemplate *template.Template -) - func init() { pageParse, err := template.New("pages").Parse(flowPopupHtml) if err != nil { @@ -44,7 +34,7 @@ func init() { func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { err := flowPopupTemplate.Execute(rw, map[string]any{ - "ServiceName": h.serviceName, + "ServiceName": h.conf.ServiceName, "Origin": req.URL.Query().Get("origin"), }) if err != nil { @@ -75,7 +65,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ // generate oauth2 config and redirect to authorize URL oa2conf := login.OAuth2Config - oa2conf.RedirectURL = h.baseUrl + "/callback" + oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginName)) http.Redirect(rw, req, nextUrl, http.StatusFound) } @@ -101,7 +91,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h } oa2conf := v.sso.OAuth2Config - oa2conf.RedirectURL = h.baseUrl + "/callback" + oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" exchange, err := oa2conf.Exchange(context.Background(), q.Get("code")) if err != nil { fmt.Println("Failed exchange:", err) @@ -157,7 +147,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h } _ = flowCallbackTemplate.Execute(rw, map[string]any{ - "ServiceName": h.serviceName, + "ServiceName": h.conf.ServiceName, "TargetOrigin": v.targetOrigin, "TargetMessage": v3, "AccessToken": accessToken, diff --git a/server/flow_test.go b/server/flow_test.go new file mode 100644 index 0000000..abb4e43 --- /dev/null +++ b/server/flow_test.go @@ -0,0 +1 @@ +package server diff --git a/server/flow-callback.go.html b/server/pages/flow-callback.go.html similarity index 100% rename from server/flow-callback.go.html rename to server/pages/flow-callback.go.html diff --git a/server/flow-popup.go.html b/server/pages/flow-popup.go.html similarity index 100% rename from server/flow-popup.go.html rename to server/pages/flow-popup.go.html diff --git a/server/pages/pages.go b/server/pages/pages.go new file mode 100644 index 0000000..e3f8618 --- /dev/null +++ b/server/pages/pages.go @@ -0,0 +1,20 @@ +package pages + +import ( + "embed" + _ "embed" + "html/template" + "os" +) + +var ( + //go:embed pages/* + flowPages embed.FS + flowTemplates *template.Template +) + +func LoadPages(wd string) { + wdFs := os.DirFS(wd) + + flowPages.Open() +} diff --git a/server/server.go b/server/server.go index f0447ad..7b0474b 100644 --- a/server/server.go +++ b/server/server.go @@ -4,21 +4,20 @@ import ( "fmt" "github.com/1f349/cache" "github.com/1f349/lavender/issuer" - "github.com/1f349/lavender/utils" "github.com/MrMelon54/mjwt" "github.com/julienschmidt/httprouter" + "log" "net/http" "time" ) type HttpServer struct { - r *httprouter.Router - baseUrl string - serviceName string - manager *issuer.Manager - signer mjwt.Signer - flowState *cache.Cache[string, flowStateData] - services map[string]struct{} + r *httprouter.Router + conf Conf + manager *issuer.Manager + signer mjwt.Signer + flowState *cache.Cache[string, flowStateData] + services map[string]struct{} } type flowStateData struct { @@ -26,30 +25,34 @@ type flowStateData struct { targetOrigin string } -func NewHttpServer(listen, baseUrl, serviceName string, clients []utils.JsonUrl, manager *issuer.Manager, signer mjwt.Signer) *http.Server { +func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server { r := httprouter.New() // remove last slash from baseUrl { - l := len(baseUrl) - if baseUrl[l-1] == '/' { - baseUrl = baseUrl[:l-1] + l := len(conf.BaseUrl) + if conf.BaseUrl[l-1] == '/' { + conf.BaseUrl = conf.BaseUrl[:l-1] } } + manager, err := issuer.NewManager(conf.SsoServices) + if err != nil { + log.Fatal("[Lavender] Failed to create SSO service manager: ", err) + } + services := make(map[string]struct{}) - for _, i := range clients { + for _, i := range conf.AllowedClients { services[i.String()] = struct{}{} } hs := &HttpServer{ - r: r, - baseUrl: baseUrl, - serviceName: serviceName, - manager: manager, - signer: signer, - flowState: cache.New[string, flowStateData](), - services: services, + r: r, + conf: conf, + manager: manager, + signer: signer, + flowState: cache.New[string, flowStateData](), + services: services, } r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { @@ -62,7 +65,7 @@ func NewHttpServer(listen, baseUrl, serviceName string, clients []utils.JsonUrl, r.GET("/callback", hs.flowCallback) return &http.Server{ - Addr: listen, + Addr: conf.Listen, Handler: r, ReadTimeout: time.Minute, ReadHeaderTimeout: time.Minute, diff --git a/server/verify.go b/server/verify.go index 67831c3..a813977 100644 --- a/server/verify.go +++ b/server/verify.go @@ -24,7 +24,7 @@ func (h *HttpServer) verifyHandler(rw http.ResponseWriter, req *http.Request, _ } // check issuer against config - if b.Issuer != h.baseUrl { + if b.Issuer != h.conf.Issuer { http.Error(rw, "Invalid issuer", http.StatusBadRequest) return } diff --git a/server/verify_test.go b/server/verify_test.go new file mode 100644 index 0000000..42862bd --- /dev/null +++ b/server/verify_test.go @@ -0,0 +1,61 @@ +package server + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/auth" + "github.com/MrMelon54/mjwt/claims" + "github.com/julienschmidt/httprouter" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestVerifyHandler(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + invalidSigner := mjwt.NewMJwtSigner("Invalid Issuer", privKey) + h := HttpServer{ + conf: Conf{Issuer: "Test Issuer"}, + signer: mjwt.NewMJwtSigner("Test Issuer", privKey), + } + + // test for missing bearer response + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "https://example.localhost", nil) + h.verifyHandler(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Equal(t, "Missing bearer\n", rec.Body.String()) + + // test for invalid token response + rec = httptest.NewRecorder() + req.Header.Set("Authorization", "Bearer abcd") + h.verifyHandler(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Equal(t, "Invalid token\n", rec.Body.String()) + + // test for invalid issuer response + rec = httptest.NewRecorder() + accessToken, err := invalidSigner.GenerateJwt("a", "a", nil, 15*time.Minute, auth.AccessTokenClaims{ + Perms: claims.NewPermStorage(), + }) + assert.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + h.verifyHandler(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "Invalid issuer\n", rec.Body.String()) + + // test for invalid issuer response + rec = httptest.NewRecorder() + accessToken, err = h.signer.GenerateJwt("a", "a", nil, 15*time.Minute, auth.AccessTokenClaims{ + Perms: claims.NewPermStorage(), + }) + assert.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + h.verifyHandler(rec, req, httprouter.Params{}) + assert.Equal(t, http.StatusOK, rec.Code) +}