Add reloading config

This commit is contained in:
Melon 2023-11-15 09:21:09 +00:00
parent 54463008f9
commit 41db9d605c
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
8 changed files with 120 additions and 83 deletions

View File

@ -44,34 +44,24 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
return subcommands.ExitUsageError return subcommands.ExitUsageError
} }
openConf, err := os.Open(s.configPath) var conf server.Conf
err := loadConfig(s.configPath, &conf)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
log.Println("[Lavender] Error: missing config file") log.Println("[Lavender] Error: missing config file")
} else { } else {
log.Println("[Lavender] Error: open config file: ", err) log.Println("[Lavender] Error: loading config file: ", err)
} }
return subcommands.ExitFailure return subcommands.ExitFailure
} }
var config server.Conf
err = json.NewDecoder(openConf).Decode(&config)
if err != nil {
log.Println("[Lavender] Error: invalid config file: ", err)
return subcommands.ExitFailure
}
configPathAbs, err := filepath.Abs(s.configPath) configPathAbs, err := filepath.Abs(s.configPath)
if err != nil { if err != nil {
log.Fatal("[Lavender] Failed to get absolute config path") log.Fatal("[Lavender] Failed to get absolute config path")
} }
wd := filepath.Dir(configPathAbs) wd := filepath.Dir(configPathAbs)
normalLoad(config, wd)
return subcommands.ExitSuccess
}
func normalLoad(startUp server.Conf, wd string) { mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(conf.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096)
mSign, err := mjwt.NewMJwtSignerFromFileOrCreate(startUp.Issuer, filepath.Join(wd, "lavender.private.key"), rand.Reader, 4096)
if err != nil { if err != nil {
log.Fatal("[Lavender] Failed to load or create MJWT signer:", err) log.Fatal("[Lavender] Failed to load or create MJWT signer:", err)
} }
@ -81,14 +71,35 @@ func normalLoad(startUp server.Conf, wd string) {
log.Fatal("[Lavender] Failed to load page templates:", err) log.Fatal("[Lavender] Failed to load page templates:", err)
} }
srv := server.NewHttpServer(startUp, mSign) srv := server.NewHttpServer(conf, mSign)
log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Addr) log.Printf("[Lavender] Starting HTTP server on '%s'\n", srv.Server.Addr)
go utils.RunBackgroundHttp("HTTP", srv) go utils.RunBackgroundHttp("HTTP", srv.Server)
exit_reload.ExitReload("Tulip", func() {}, func() { exit_reload.ExitReload("Lavender", func() {
var conf server.Conf
err := loadConfig(s.configPath, &conf)
if err != nil {
log.Println("[Lavender] Failed to read config:", err)
}
err = srv.UpdateConfig(conf)
if err != nil {
log.Println("[Lavender] Failed to reload config:", err)
}
}, func() {
// stop http server // stop http server
_ = srv.Close() _ = srv.Server.Close()
}) })
return subcommands.ExitSuccess
}
func loadConfig(configPath string, conf *server.Conf) error {
openConf, err := os.Open(configPath)
if err != nil {
return err
}
return json.NewDecoder(openConf).Decode(conf)
} }
func saveMjwtPubKey(mSign mjwt.Signer, wd string) { func saveMjwtPubKey(mSign mjwt.Signer, wd string) {

View File

@ -12,7 +12,7 @@ type Conf struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
SsoServices []issuer.SsoConfig `json:"sso_services"` SsoServices []issuer.SsoConfig `json:"sso_services"`
AllowedClients []AllowedClient `json:"allowed_clients"` AllowedClients []AllowedClient `json:"allowed_clients"`
Ownership DomainOwnership `json:"ownership"` Users UserConfig `json:"users"`
} }
type AllowedClient struct { type AllowedClient struct {

View File

@ -37,14 +37,14 @@ func (h *HttpServer) flowPopup(rw http.ResponseWriter, req *http.Request, _ http
cookie, err := req.Cookie("lavender-login-name") cookie, err := req.Cookie("lavender-login-name")
if err == nil && cookie.Valid() == nil { if err == nil && cookie.Valid() == nil {
pages.RenderPageTemplate(rw, "flow-popup-memory", map[string]any{ pages.RenderPageTemplate(rw, "flow-popup-memory", map[string]any{
"ServiceName": h.conf.ServiceName, "ServiceName": h.conf.Load().ServiceName,
"Origin": req.URL.Query().Get("origin"), "Origin": req.URL.Query().Get("origin"),
"LoginName": cookie.Value, "LoginName": cookie.Value,
}) })
return return
} }
pages.RenderPageTemplate(rw, "flow-popup", map[string]any{ pages.RenderPageTemplate(rw, "flow-popup", map[string]any{
"ServiceName": h.conf.ServiceName, "ServiceName": h.conf.Load().ServiceName,
"Origin": req.URL.Query().Get("origin"), "Origin": req.URL.Query().Get("origin"),
}) })
} }
@ -68,7 +68,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
return return
} }
loginName := req.PostFormValue("loginname") loginName := req.PostFormValue("loginname")
login := h.manager.FindServiceFromLogin(loginName) login := h.manager.Load().FindServiceFromLogin(loginName)
if login == nil { if login == nil {
http.Error(rw, "No login service defined for this username", http.StatusBadRequest) http.Error(rw, "No login service defined for this username", http.StatusBadRequest)
return return
@ -90,7 +90,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
}) })
targetOrigin := req.PostFormValue("origin") targetOrigin := req.PostFormValue("origin")
allowedService, found := h.services[targetOrigin] allowedService, found := (*h.services.Load())[targetOrigin]
if !found { if !found {
http.Error(rw, "Invalid target origin", http.StatusBadRequest) http.Error(rw, "Invalid target origin", http.StatusBadRequest)
return return
@ -105,7 +105,7 @@ func (h *HttpServer) flowPopupPost(rw http.ResponseWriter, req *http.Request, _
// generate oauth2 config and redirect to authorize URL // generate oauth2 config and redirect to authorize URL
oa2conf := login.OAuth2Config oa2conf := login.OAuth2Config
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" oa2conf.RedirectURL = h.conf.Load().BaseUrl + "/callback"
nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn)) nextUrl := oa2conf.AuthCodeURL(state, oauth2.SetAuthURLParam("login_name", loginUn))
http.Redirect(rw, req, nextUrl, http.StatusFound) http.Redirect(rw, req, nextUrl, http.StatusFound)
} }
@ -120,7 +120,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
q := req.URL.Query() q := req.URL.Query()
state := q.Get("state") state := q.Get("state")
n := strings.IndexByte(state, ':') n := strings.IndexByte(state, ':')
if n == -1 || !h.manager.CheckNamespace(state[:n]) { if n == -1 || !h.manager.Load().CheckNamespace(state[:n]) {
http.Error(rw, "Invalid state namespace", http.StatusBadRequest) http.Error(rw, "Invalid state namespace", http.StatusBadRequest)
return return
} }
@ -131,7 +131,7 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
} }
oa2conf := v.sso.OAuth2Config oa2conf := v.sso.OAuth2Config
oa2conf.RedirectURL = h.conf.BaseUrl + "/callback" oa2conf.RedirectURL = h.conf.Load().BaseUrl + "/callback"
exchange, err := testOa2Exchange(oa2conf, context.Background(), q.Get("code")) exchange, err := testOa2Exchange(oa2conf, context.Background(), q.Get("code"))
if err != nil { if err != nil {
fmt.Println("Failed exchange:", err) fmt.Println("Failed exchange:", err)
@ -193,25 +193,22 @@ func (h *HttpServer) flowCallback(rw http.ResponseWriter, req *http.Request, _ h
return return
} }
n := strings.IndexByte(address.Address, '@') n := strings.IndexByte(address.Address, '@')
if n == -1 { if n != -1 {
goto noEmailSupport if address.Address[n+1:] == v.sso.Config.Namespace {
ps.Set("mail:client")
} }
if address.Address[n+1:] != v.sso.Config.Namespace {
goto noEmailSupport
} }
ps.Set("mail-client")
} }
} }
} }
if needsDomains { if needsDomains {
a := h.conf.Ownership.AllOwns(sub + "@" + v.sso.Config.Namespace) a := h.conf.Load().Users.AllDomains(sub + "@" + v.sso.Config.Namespace)
for _, i := range a { for _, i := range a {
ps.Set("domain:owns=" + i) ps.Set("domain:owns=" + i)
} }
} }
noEmailSupport:
nsSub := sub + "@" + v.sso.Config.Namespace nsSub := sub + "@" + v.sso.Config.Namespace
ati := uuidNewStringAti() ati := uuidNewStringAti()
accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{ accessToken, err := h.signer.GenerateJwt(nsSub, ati, jwt.ClaimStrings{aud}, 15*time.Minute, auth.AccessTokenClaims{
@ -229,7 +226,7 @@ noEmailSupport:
} }
pages.RenderPageTemplate(rw, "flow-callback", map[string]any{ pages.RenderPageTemplate(rw, "flow-callback", map[string]any{
"ServiceName": h.conf.ServiceName, "ServiceName": h.conf.Load().ServiceName,
"TargetOrigin": v.target.Url.String(), "TargetOrigin": v.target.Url.String(),
"TargetMessage": v3, "TargetMessage": v3,
"AccessToken": accessToken, "AccessToken": accessToken,

View File

@ -66,18 +66,19 @@ var testOidc = &issuer.WellKnownOIDC{
var testManager = issuer.NewManagerForTests([]*issuer.WellKnownOIDC{testOidc}) var testManager = issuer.NewManagerForTests([]*issuer.WellKnownOIDC{testOidc})
var testHttpServer = HttpServer{ var testHttpServer = HttpServer{
r: nil, r: nil,
conf: Conf{
BaseUrl: lavenderDomain,
ServiceName: "Test Lavender Service",
},
manager: testManager,
flowState: cache.New[string, flowStateData](), flowState: cache.New[string, flowStateData](),
services: map[string]AllowedClient{
clientAppDomain: {},
},
} }
func init() { func init() {
testHttpServer.conf.Store(&Conf{
BaseUrl: lavenderDomain,
ServiceName: "Test Lavender Service",
})
testHttpServer.manager.Store(testManager)
testHttpServer.services.Store(&map[string]AllowedClient{
clientAppDomain: {},
})
err := pages.LoadPages("") err := pages.LoadPages("")
if err != nil { if err != nil {
panic(err) panic(err)
@ -103,7 +104,8 @@ func init() {
} }
func TestFlowPopup(t *testing.T) { func TestFlowPopup(t *testing.T) {
h := HttpServer{conf: Conf{ServiceName: "Test Service Name"}} h := HttpServer{}
h.conf.Store(&Conf{ServiceName: "Test Service Name"})
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/popup?"+url.Values{"origin": []string{clientAppDomain}}.Encode(), nil) req := httptest.NewRequest(http.MethodGet, "/popup?"+url.Values{"origin": []string{clientAppDomain}}.Encode(), nil)
h.flowPopup(rec, req, httprouter.Params{}) h.flowPopup(rec, req, httprouter.Params{})

View File

@ -1,14 +1,30 @@
package server package server
// DomainOwnership is the structure for storing if a user owns a domain // UserConfig is the structure for storing a user's role and owned domains
type DomainOwnership map[string][]string type UserConfig map[string]struct {
Roles []string `json:"roles"`
func (d DomainOwnership) AllOwns(user string) []string { Domains []string `json:"domains"`
return d[user]
} }
func (d DomainOwnership) Owns(user, domain string) bool { func (u UserConfig) AllRoles(user string) []string {
for _, i := range d[user] { return u[user].Roles
}
func (u UserConfig) HasRole(user, role string) bool {
for _, i := range u[user].Roles {
if i == role {
return true
}
}
return false
}
func (u UserConfig) AllDomains(user string) []string {
return u[user].Domains
}
func (u UserConfig) OwnsDomain(user, domain string) bool {
for _, i := range u[user].Domains {
if i == domain { if i == domain {
return true return true
} }

View File

@ -8,16 +8,18 @@ import (
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"log" "log"
"net/http" "net/http"
"sync/atomic"
"time" "time"
) )
type HttpServer struct { type HttpServer struct {
Server *http.Server
r *httprouter.Router r *httprouter.Router
conf Conf conf atomic.Pointer[Conf]
manager *issuer.Manager manager atomic.Pointer[issuer.Manager]
signer mjwt.Signer signer mjwt.Signer
flowState *cache.Cache[string, flowStateData] flowState *cache.Cache[string, flowStateData]
services map[string]AllowedClient services atomic.Pointer[map[string]AllowedClient]
} }
type flowStateData struct { type flowStateData struct {
@ -25,7 +27,7 @@ type flowStateData struct {
target AllowedClient target AllowedClient
} }
func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server { func NewHttpServer(conf Conf, signer mjwt.Signer) *HttpServer {
r := httprouter.New() r := httprouter.New()
// remove last slash from baseUrl // remove last slash from baseUrl
@ -36,23 +38,24 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server {
} }
} }
manager, err := issuer.NewManager(conf.SsoServices)
if err != nil {
log.Fatal("[Lavender] Failed to create SSO service manager: ", err)
}
services := make(map[string]AllowedClient)
for _, i := range conf.AllowedClients {
services[i.Url.String()] = i
}
hs := &HttpServer{ hs := &HttpServer{
Server: &http.Server{
Addr: conf.Listen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
},
r: r, r: r,
conf: conf,
manager: manager,
signer: signer, signer: signer,
flowState: cache.New[string, flowStateData](), flowState: cache.New[string, flowStateData](),
services: services, }
err := hs.UpdateConfig(conf)
if err != nil {
log.Fatalln("Failed to load initial config:", err)
return nil
} }
r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { r.GET("/", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
@ -63,14 +66,22 @@ func NewHttpServer(conf Conf, signer mjwt.Signer) *http.Server {
r.GET("/popup", hs.flowPopup) r.GET("/popup", hs.flowPopup)
r.POST("/popup", hs.flowPopupPost) r.POST("/popup", hs.flowPopupPost)
r.GET("/callback", hs.flowCallback) r.GET("/callback", hs.flowCallback)
return hs
}
return &http.Server{ func (h *HttpServer) UpdateConfig(conf Conf) error {
Addr: conf.Listen, m, err := issuer.NewManager(conf.SsoServices)
Handler: r, if err != nil {
ReadTimeout: time.Minute, return fmt.Errorf("failed to reload SSO service manager: %w", err)
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
} }
clientLookup := make(map[string]AllowedClient)
for _, i := range conf.AllowedClients {
clientLookup[i.Url.String()] = i
}
h.conf.Store(&conf)
h.manager.Store(m)
h.services.Store(&clientLookup)
return nil
} }

View File

@ -1,9 +1,9 @@
package server package server
import ( import (
"github.com/1f349/violet/utils"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth" "github.com/1f349/mjwt/auth"
"github.com/1f349/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
) )
@ -24,7 +24,7 @@ func (h *HttpServer) verifyHandler(rw http.ResponseWriter, req *http.Request, _
} }
// check issuer against config // check issuer against config
if b.Issuer != h.conf.Issuer { if b.Issuer != h.conf.Load().Issuer {
http.Error(rw, "Invalid issuer", http.StatusBadRequest) http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return return
} }

View File

@ -20,9 +20,9 @@ func TestVerifyHandler(t *testing.T) {
invalidSigner := mjwt.NewMJwtSigner("Invalid Issuer", privKey) invalidSigner := mjwt.NewMJwtSigner("Invalid Issuer", privKey)
h := HttpServer{ h := HttpServer{
conf: Conf{Issuer: "Test Issuer"},
signer: mjwt.NewMJwtSigner("Test Issuer", privKey), signer: mjwt.NewMJwtSigner("Test Issuer", privKey),
} }
h.conf.Store(&Conf{Issuer: "Test Issuer"})
// test for missing bearer response // test for missing bearer response
rec := httptest.NewRecorder() rec := httptest.NewRecorder()