diff --git a/cmd/lavender/serve.go b/cmd/lavender/serve.go index f904638..5970748 100644 --- a/cmd/lavender/serve.go +++ b/cmd/lavender/serve.go @@ -44,34 +44,24 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) return subcommands.ExitUsageError } - openConf, err := os.Open(s.configPath) + var conf server.Conf + err := loadConfig(s.configPath, &conf) if err != nil { if os.IsNotExist(err) { log.Println("[Lavender] Error: missing config file") } else { - log.Println("[Lavender] Error: open config file: ", err) + log.Println("[Lavender] Error: loading config file: ", err) } return subcommands.ExitFailure } - var config server.Conf - err = json.NewDecoder(openConf).Decode(&config) - if err != nil { - log.Println("[Lavender] Error: invalid config file: ", err) - return subcommands.ExitFailure - } - configPathAbs, err := filepath.Abs(s.configPath) if err != nil { log.Fatal("[Lavender] Failed to get absolute config path") } wd := filepath.Dir(configPathAbs) - normalLoad(config, wd) - return subcommands.ExitSuccess -} -func normalLoad(startUp server.Conf, wd string) { - mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096) + mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(conf.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096) if err != nil { log.Fatal("[Lavender] Failed to load or create MJWT signer:", err) } @@ -81,14 +71,35 @@ func normalLoad(startUp server.Conf, wd string) { log.Fatal("[Lavender] Failed to load page templates:", err) } - srv := server.NewHttpServer(startUp, mSign) - log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr) - go utils.RunBackgroundHttp("HTTP", srv) + srv := server.NewHttpServer(conf, mSign) + log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Server.Addr) + go utils.RunBackgroundHttp("HTTP", srv.Server) - exit_reload.ExitReload("Tulip", func() {}, func() { + exit_reload.ExitReload("Lavender", func() { + var conf server.Conf + err := loadConfig(s.configPath, &conf) + if err != nil { + log.Println("[Lavender] Failed to read config:", err) + } + err = srv.UpdateConfig(conf) + if err != nil { + log.Println("[Lavender] Failed to reload config:", err) + } + }, func() { // stop http server - _ = srv.Close() + _ = srv.Server.Close() }) + + return subcommands.ExitSuccess +} + +func loadConfig(configPath string, conf *server.Conf) error { + openConf, err := os.Open(configPath) + if err != nil { + return err + } + + return json.NewDecoder(openConf).Decode(conf) } func saveMjwtPubKey(mSign mjwt.Signer, wd string) { diff --git a/server/conf.go b/server/conf.go index 382a365..719db7c 100644 --- a/server/conf.go +++ b/server/conf.go @@ -12,7 +12,7 @@ type Conf struct { Issuer string `json:"issuer"` SsoServices []issuer.SsoConfig `json:"sso_services"` AllowedClients []AllowedClient `json:"allowed_clients"` - Ownership DomainOwnership `json:"ownership"` + Users UserConfig `json:"users"` } type AllowedClient struct { diff --git a/server/flow.go b/server/flow.go index 6ce455a..caea9f5 100644 --- a/server/flow.go +++ b/server/flow.go @@ -37,14 +37,14 @@ func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ http cookie, err := req.Cookie("lavender-login-name") if err == nil && cookie.Valid() == nil { pages.RenderPageTemplate(rw, "flow-popup-memory", map[string]any{ - "ServiceName": h.conf.ServiceName, + "ServiceName": h.conf.Load().ServiceName, "Origin": req.URL.Query().Get("origin"), "LoginName": cookie.Value, }) return } pages.RenderPageTemplate(rw, "flow-popup", map[string]any{ - "ServiceName": h.conf.ServiceName, + "ServiceName": h.conf.Load().ServiceName, "Origin": req.URL.Query().Get("origin"), }) } @@ -68,7 +68,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ return } loginName := req.PostFormValue("loginname") - login := h.manager.FindServiceFromLogin(loginName) + login := h.manager.Load().FindServiceFromLogin(loginName) if login == nil { http.Error(rw, "No login service defined for this username", http.StatusBadRequest) return @@ -90,7 +90,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ }) targetOrigin := req.PostFormValue("origin") - allowedService, found := h.services[targetOrigin] + allowedService, found := (*h.services.Load())[targetOrigin] if !found { http.Error(rw, "Invalid target origin", http.StatusBadRequest) return @@ -105,7 +105,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _ // generate oauth2 config and redirect to authorize URL oa2conf := login.OAuth2Config - oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" + oa2conf.RedirectURL = h.conf.Load().BaseUrl + "/callback" nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn)) http.Redirect(rw, req, nextUrl, http.StatusFound) } @@ -120,7 +120,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h q := req.URL.Query() state := q.Get("state") n := strings.IndexByte(state, ':') - if n == -1 || !h.manager.CheckNamespace(state[:n]) { + if n == -1 || !h.manager.Load().CheckNamespace(state[:n]) { http.Error(rw, "Invalid state namespace", http.StatusBadRequest) return } @@ -131,7 +131,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h } oa2conf := v.sso.OAuth2Config - oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" + oa2conf.RedirectURL = h.conf.Load().BaseUrl + "/callback" exchange, err := testOa2Exchange(oa2conf, context.Background(), q.Get("code")) if err != nil { fmt.Println("Failed exchange:", err) @@ -193,25 +193,22 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h return } n := strings.IndexByte(address.Address, '@') - if n == -1 { - goto noEmailSupport + if n != -1 { + if address.Address[n+1:] == v.sso.Config.Namespace { + ps.Set("mail:client") + } } - if address.Address[n+1:] != v.sso.Config.Namespace { - goto noEmailSupport - } - ps.Set("mail-client") } } } if needsDomains { - a := h.conf.Ownership.AllOwns(sub + "@" + v.sso.Config.Namespace) + a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace) for _, i := range a { ps.Set("domain:owns=" + i) } } -noEmailSupport: nsSub := sub + "@" + v.sso.Config.Namespace ati := uuidNewStringAti() accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{ @@ -229,7 +226,7 @@ noEmailSupport: } pages.RenderPageTemplate(rw, "flow-callback", map[string]any{ - "ServiceName": h.conf.ServiceName, + "ServiceName": h.conf.Load().ServiceName, "TargetOrigin": v.target.Url.String(), "TargetMessage": v3, "AccessToken": accessToken, diff --git a/server/flow_test.go b/server/flow_test.go index 8258b61..2a0d84c 100644 --- a/server/flow_test.go +++ b/server/flow_test.go @@ -65,19 +65,20 @@ var testOidc = &issuer.WellKnownOIDC{ var testManager = issuer.NewManagerForTests([]*issuer.WellKnownOIDC{testOidc}) var testHttpServer = HttpServer{ - r: nil, - conf: Conf{ - BaseUrl: lavenderDomain, - ServiceName: "Test Lavender Service", - }, - manager: testManager, + r: nil, flowState: cache.New[string, flowStateData](), - services: map[string]AllowedClient{ - clientAppDomain: {}, - }, } func init() { + testHttpServer.conf.Store(&Conf{ + BaseUrl: lavenderDomain, + ServiceName: "Test Lavender Service", + }) + testHttpServer.manager.Store(testManager) + testHttpServer.services.Store(&map[string]AllowedClient{ + clientAppDomain: {}, + }) + err := pages.LoadPages("") if err != nil { panic(err) @@ -103,7 +104,8 @@ func init() { } func TestFlowPopup(t *testing.T) { - h := HttpServer{conf: Conf{ServiceName: "Test Service Name"}} + h := HttpServer{} + h.conf.Store(&Conf{ServiceName: "Test Service Name"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/popup?"+url.Values{"origin": []string{clientAppDomain}}.Encode(), nil) h.flowPopup(rec, req, httprouter.Params{}) diff --git a/server/owners.go b/server/owners.go index 9ff797c..9a13a20 100644 --- a/server/owners.go +++ b/server/owners.go @@ -1,14 +1,30 @@ package server -// DomainOwnership is the structure for storing if a user owns a domain -type DomainOwnership map[string][]string - -func (d DomainOwnership) AllOwns(user string) []string { - return d[user] +// UserConfig is the structure for storing a user's role and owned domains +type UserConfig map[string]struct { + Roles []string `json:"roles"` + Domains []string `json:"domains"` } -func (d DomainOwnership) Owns(user, domain string) bool { - for _, i := range d[user] { +func (u UserConfig) AllRoles(user string) []string { + return u[user].Roles +} + +func (u UserConfig) HasRole(user, role string) bool { + for _, i := range u[user].Roles { + if i == role { + return true + } + } + return false +} + +func (u UserConfig) AllDomains(user string) []string { + return u[user].Domains +} + +func (u UserConfig) OwnsDomain(user, domain string) bool { + for _, i := range u[user].Domains { if i == domain { return true } diff --git a/server/server.go b/server/server.go index 163bb26..0a6a9f8 100644 --- a/server/server.go +++ b/server/server.go @@ -8,16 +8,18 @@ import ( "github.com/julienschmidt/httprouter" "log" "net/http" + "sync/atomic" "time" ) type HttpServer struct { + Server *http.Server r *httprouter.Router - conf Conf - manager *issuer.Manager + conf atomic.Pointer[Conf] + manager atomic.Pointer[issuer.Manager] signer mjwt.Signer flowState *cache.Cache[string, flowStateData] - services map[string]AllowedClient + services atomic.Pointer[map[string]AllowedClient] } type flowStateData struct { @@ -25,7 +27,7 @@ type flowStateData struct { target AllowedClient } -func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server { +func NewHttpServer(conf Conf, signer mjwt.Signer) *HttpServer { r := httprouter.New() // remove last slash from baseUrl @@ -36,23 +38,24 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server { } } - manager, err := issuer.NewManager(conf.SsoServices) - if err != nil { - log.Fatal("[Lavender] Failed to create SSO service manager: ", err) - } - - services := make(map[string]AllowedClient) - for _, i := range conf.AllowedClients { - services[i.Url.String()] = i - } - hs := &HttpServer{ + Server: &http.Server{ + Addr: conf.Listen, + Handler: r, + ReadTimeout: time.Minute, + ReadHeaderTimeout: time.Minute, + WriteTimeout: time.Minute, + IdleTimeout: time.Minute, + MaxHeaderBytes: 2500, + }, r: r, - conf: conf, - manager: manager, signer: signer, flowState: cache.New[string, flowStateData](), - services: services, + } + err := hs.UpdateConfig(conf) + if err != nil { + log.Fatalln("Failed to load initial config:", err) + return nil } r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { @@ -63,14 +66,22 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server { r.GET("/popup", hs.flowPopup) r.POST("/popup", hs.flowPopupPost) r.GET("/callback", hs.flowCallback) - - return &http.Server{ - Addr: conf.Listen, - Handler: r, - ReadTimeout: time.Minute, - ReadHeaderTimeout: time.Minute, - WriteTimeout: time.Minute, - IdleTimeout: time.Minute, - MaxHeaderBytes: 2500, - } + return hs +} + +func (h *HttpServer) UpdateConfig(conf Conf) error { + m, err := issuer.NewManager(conf.SsoServices) + if err != nil { + return fmt.Errorf("failed to reload SSO service manager: %w", err) + } + + clientLookup := make(map[string]AllowedClient) + for _, i := range conf.AllowedClients { + clientLookup[i.Url.String()] = i + } + + h.conf.Store(&conf) + h.manager.Store(m) + h.services.Store(&clientLookup) + return nil } diff --git a/server/verify.go b/server/verify.go index d3eaf74..53746be 100644 --- a/server/verify.go +++ b/server/verify.go @@ -1,9 +1,9 @@ package server import ( - "github.com/1f349/violet/utils" "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" + "github.com/1f349/violet/utils" "github.com/julienschmidt/httprouter" "net/http" ) @@ -24,7 +24,7 @@ func (h *HttpServer) verifyHandler(rw http.ResponseWriter, req *http.Request, _ } // check issuer against config - if b.Issuer != h.conf.Issuer { + if b.Issuer != h.conf.Load().Issuer { http.Error(rw, "Invalid issuer", http.StatusBadRequest) return } diff --git a/server/verify_test.go b/server/verify_test.go index 1d48bfe..fb2e907 100644 --- a/server/verify_test.go +++ b/server/verify_test.go @@ -20,9 +20,9 @@ func TestVerifyHandler(t *testing.T) { invalidSigner := mjwt.NewMJwtSigner("Invalid Issuer", privKey) h := HttpServer{ - conf: Conf{Issuer: "Test Issuer"}, signer: mjwt.NewMJwtSigner("Test Issuer", privKey), } + h.conf.Store(&Conf{Issuer: "Test Issuer"}) // test for missing bearer response rec := httptest.NewRecorder()