diff --git a/cmd/violet/serve.go b/cmd/violet/serve.go index 8e79f80..b11504e 100644 --- a/cmd/violet/serve.go +++ b/cmd/violet/serve.go @@ -18,35 +18,53 @@ import ( "github.com/1f349/violet/servers/api" "github.com/1f349/violet/servers/conf" "github.com/1f349/violet/utils" + "github.com/charmbracelet/log" + "github.com/cloudflare/tableflip" "github.com/google/subcommands" - "github.com/mrmelon54/exit-reload" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "io/fs" "net/http" "os" + "os/signal" "path/filepath" + "syscall" + "time" ) type serveCmd struct { configPath string - cpuprofile string + debugLog bool + pidFile string } func (s *serveCmd) Name() string { return "serve" } func (s *serveCmd) Synopsis() string { return "Serve reverse proxy server" } func (s *serveCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&s.configPath, "conf", "", "/path/to/config.json : path to the config file") + f.BoolVar(&s.debugLog, "debug", false, "enable debug logging") + f.StringVar(&s.pidFile, "pid-file", "", "path to pid file") } func (s *serveCmd) Usage() string { - return `serve [-conf ] + return `serve [-conf ] [-debug] [-pid-file ] Serve reverse proxy server using information from config file ` } func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + if s.debugLog { + logger.Logger.SetLevel(log.DebugLevel) + } logger.Logger.Info("Starting...") + upg, err := tableflip.New(tableflip.Options{ + PIDFile: s.pidFile, + }) + if err != nil { + panic(err) + } + defer upg.Stop() + if s.configPath == "" { logger.Logger.Info("Error: config flag is missing") return subcommands.ExitUsageError @@ -71,13 +89,9 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) // working directory is the parent of the config file wd := filepath.Dir(s.configPath) - normalLoad(config, wd) - return subcommands.ExitSuccess -} -func normalLoad(startUp startUpConfig, wd string) { // the cert and key paths are useless in self-signed mode - if !startUp.SelfSigned { + if !config.SelfSigned { // create path to cert dir err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm) if err != nil { @@ -92,11 +106,11 @@ func normalLoad(startUp startUpConfig, wd string) { // errorPageDir stores an FS interface for accessing the error page directory var errorPageDir fs.FS - if startUp.ErrorPagePath != "" { - errorPageDir = os.DirFS(startUp.ErrorPagePath) - err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm) + if config.ErrorPagePath != "" { + errorPageDir = os.DirFS(config.ErrorPagePath) + err := os.MkdirAll(config.ErrorPagePath, os.ModePerm) if err != nil { - logger.Logger.Fatal("Failed to create error page", "path", startUp.ErrorPagePath) + logger.Logger.Fatal("Failed to create error page", "path", config.ErrorPagePath) } } @@ -123,75 +137,113 @@ func normalLoad(startUp startUpConfig, wd string) { ) ws := websocket.NewServer() - allowedDomains := domains.New(db) // load allowed domains - acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store - allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager - hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy - dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider - dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider - dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager + allowedDomains := domains.New(db) // load allowed domains + acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store + allowedCerts := certs.New(certDir, keyDir, config.SelfSigned) // load certificate manager + hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy + dynamicFavicons := favicons.New(db, config.InkscapeCmd) // load dynamic favicon provider + dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider + dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager // struct containing config for the http servers srvConf := &conf.Conf{ - ApiListen: startUp.Listen.Api, - HttpListen: startUp.Listen.Http, - HttpsListen: startUp.Listen.Https, - RateLimit: startUp.RateLimit, - DB: db, - Domains: allowedDomains, - Acme: acmeChallenges, - Certs: allowedCerts, - Favicons: dynamicFavicons, - Signer: mJwtVerify, - ErrorPages: dynamicErrorPages, - Router: dynamicRouter, + RateLimit: config.RateLimit, + DB: db, + Domains: allowedDomains, + Acme: acmeChallenges, + Certs: allowedCerts, + Favicons: dynamicFavicons, + Signer: mJwtVerify, + ErrorPages: dynamicErrorPages, + Router: dynamicRouter, } // create the compilable list and run a first time compile allCompilables := utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter} allCompilables.Compile() + _, httpsPort, ok := utils.SplitDomainPort(config.Listen.Https, 443) + if !ok { + httpsPort = 443 + } + var srvApi, srvHttp, srvHttps *http.Server - if srvConf.ApiListen != "" { + if config.Listen.Api != "" { + // Listen must be called before Ready + lnApi, err := upg.Listen("tcp", config.Listen.Api) + if err != nil { + logger.Logger.Fatal("Listen failed", "err", err) + } srvApi = api.NewApiServer(srvConf, allCompilables, promRegistry) srvApi.SetKeepAlivesEnabled(false) l := logger.Logger.With("server", "API") l.Info("Starting server", "addr", srvApi.Addr) - go utils.RunBackgroundHttp(l, srvApi) + go utils.RunBackgroundHttp(l, srvApi, lnApi) } - if srvConf.HttpListen != "" { - srvHttp = servers.NewHttpServer(srvConf, promRegistry) + if config.Listen.Http != "" { + // Listen must be called before Ready + lnHttp, err := upg.Listen("tcp", config.Listen.Http) + if err != nil { + logger.Logger.Fatal("Listen failed", "err", err) + } + srvHttp = servers.NewHttpServer(uint16(httpsPort), srvConf, promRegistry) srvHttp.SetKeepAlivesEnabled(false) l := logger.Logger.With("server", "HTTP") l.Info("Starting server", "addr", srvHttp.Addr) - go utils.RunBackgroundHttp(l, srvHttp) + go utils.RunBackgroundHttp(l, srvHttp, lnHttp) } - if srvConf.HttpsListen != "" { + if config.Listen.Https != "" { + // Listen must be called before Ready + lnHttps, err := upg.Listen("tcp", config.Listen.Https) + if err != nil { + logger.Logger.Fatal("Listen failed", "err", err) + } srvHttps = servers.NewHttpsServer(srvConf, promRegistry) srvHttps.SetKeepAlivesEnabled(false) l := logger.Logger.With("server", "HTTPS") l.Info("Starting server", "addr", srvHttps.Addr) - go utils.RunBackgroundHttps(l, srvHttps) + go utils.RunBackgroundHttps(l, srvHttps, lnHttps) } - exit_reload.ExitReload("Violet", func() { - allCompilables.Compile() - }, func() { - // stop updating certificates - allowedCerts.Stop() + // Do an upgrade on SIGHUP + go func() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGHUP) + for range sig { + err := upg.Upgrade() + if err != nil { + logger.Logger.Error("Failed upgrade", "err", err) + } + } + }() - // close websockets first - ws.Shutdown() + logger.Logger.Info("Ready") + if err := upg.Ready(); err != nil { + panic(err) + } + <-upg.Exit() - // close http servers - if srvApi != nil { - _ = srvApi.Close() - } - if srvHttp != nil { - _ = srvHttp.Close() - } - if srvHttps != nil { - _ = srvHttps.Close() - } + time.AfterFunc(30*time.Second, func() { + logger.Logger.Warn("Graceful shutdown timed out") + os.Exit(1) }) + + // stop updating certificates + allowedCerts.Stop() + + // close websockets first + ws.Shutdown() + + // close http servers + if srvApi != nil { + _ = srvApi.Close() + } + if srvHttp != nil { + _ = srvHttp.Close() + } + if srvHttps != nil { + _ = srvHttps.Close() + } + + return subcommands.ExitSuccess } diff --git a/go.mod b/go.mod index 9eb96dd..0901161 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/1f349/mjwt v0.2.5 github.com/AlecAivazis/survey/v2 v2.3.7 github.com/charmbracelet/log v0.4.0 + github.com/cloudflare/tableflip v1.2.3 github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/subcommands v1.2.0 github.com/google/uuid v1.6.0 @@ -13,7 +14,6 @@ require ( github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mrmelon54/certgen v0.0.2 - github.com/mrmelon54/exit-reload v0.0.2 github.com/mrmelon54/png2ico v1.0.2 github.com/mrmelon54/rescheduler v0.0.3 github.com/mrmelon54/trie v0.0.3 diff --git a/go.sum b/go.sum index e97a765..779bc6f 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMt github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE= github.com/charmbracelet/log v0.4.0 h1:G9bQAcx8rWA2T3pWvx7YtPTPwgqpk7D68BX21IRW8ZM= github.com/charmbracelet/log v0.4.0/go.mod h1:63bXt/djrizTec0l11H20t8FDSvA4CRZJ1KH22MdptM= +github.com/cloudflare/tableflip v1.2.3 h1:8I+B99QnnEWPHOY3fWipwVKxS70LGgUsslG7CSfmHMw= +github.com/cloudflare/tableflip v1.2.3/go.mod h1:P4gRehmV6Z2bY5ao5ml9Pd8u6kuEnlB37pUFMmv7j2E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= @@ -72,8 +74,6 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mrmelon54/certgen v0.0.2 h1:4CMDkA/gGZu+E4iikU+5qdOWK7qOQrk58KtUfnmyYmY= github.com/mrmelon54/certgen v0.0.2/go.mod h1:vwrWSXQmxZYqEyh+cf05IvDIFV2aYuxL4+O6ABIlN8M= -github.com/mrmelon54/exit-reload v0.0.2 h1:vqgfrMD/bF21HkDsWgg5+NLjFDrD3KGVEN/iTrMn9Ms= -github.com/mrmelon54/exit-reload v0.0.2/go.mod h1:aE3NhsqGMLUqmv6cJZRouC/8gXkZTvVSabRGOpI+Vjc= github.com/mrmelon54/png2ico v1.0.2 h1:KyJd3ATmDjxAJS28MTSf44GxzYnlZ+7KT8SXzGb3sN8= github.com/mrmelon54/png2ico v1.0.2/go.mod h1:vp8Be9y5cz102ANon+BnsIzTUdet3VQRvOuWJTH9h0M= github.com/mrmelon54/rescheduler v0.0.3 h1:TrkJL6S7PKvXuo1mvdgRgsILA/pk5L1lrXhV/q7IEzQ= @@ -130,6 +130,7 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/servers/api/api.go b/servers/api/api.go index 36bbd55..07c8f3c 100644 --- a/servers/api/api.go +++ b/servers/api/api.go @@ -48,7 +48,6 @@ func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable, registry // Create and run http server return &http.Server{ - Addr: conf.ApiListen, Handler: r, ReadTimeout: time.Minute, ReadHeaderTimeout: time.Minute, diff --git a/servers/conf/conf.go b/servers/conf/conf.go index cdbb428..61874d9 100644 --- a/servers/conf/conf.go +++ b/servers/conf/conf.go @@ -11,16 +11,13 @@ import ( // Conf stores the shared configuration for the API, HTTP and HTTPS servers. type Conf struct { - ApiListen string // api server listen address - HttpListen string // http server listen address - HttpsListen string // https server listen address - RateLimit uint64 // rate limit per minute - DB *database.Queries - Domains utils.DomainProvider - Acme utils.AcmeChallengeProvider - Certs utils.CertProvider - Favicons *favicons.Favicons - Signer mjwt.Verifier - ErrorPages *errorPages.ErrorPages - Router *router.Manager + RateLimit uint64 // rate limit per minute + DB *database.Queries + Domains utils.DomainProvider + Acme utils.AcmeChallengeProvider + Certs utils.CertProvider + Favicons *favicons.Favicons + Signer mjwt.Verifier + ErrorPages *errorPages.ErrorPages + Router *router.Manager } diff --git a/servers/http.go b/servers/http.go index 886b750..8746c89 100644 --- a/servers/http.go +++ b/servers/http.go @@ -17,13 +17,9 @@ import ( // // `/.well-known/acme-challenge/{token}` is used for outputting answers for // acme challenges, this is used for Let's Encrypt HTTP verification. -func NewHttpServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server { +func NewHttpServer(httpsPort uint16, conf *conf.Conf, registry *prometheus.Registry) *http.Server { r := httprouter.New() var secureExtend string - _, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443) - if !ok { - httpsPort = 443 - } if httpsPort != 443 { secureExtend = fmt.Sprintf(":%d", httpsPort) } @@ -72,7 +68,6 @@ func NewHttpServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server // Create and run http server return &http.Server{ - Addr: conf.HttpListen, Handler: metricsMiddleware, ReadTimeout: time.Minute, ReadHeaderTimeout: time.Minute, diff --git a/servers/http_test.go b/servers/http_test.go index cb66080..4c80017 100644 --- a/servers/http_test.go +++ b/servers/http_test.go @@ -18,7 +18,7 @@ func TestNewHttpServer_AcmeChallenge(t *testing.T) { Acme: utils.NewAcmeChallenge(), Signer: fake.SnakeOilProv, } - srv := NewHttpServer(httpConf, nil) + srv := NewHttpServer(443, httpConf, nil) httpConf.Acme.Put("example.com", "456", "456def") req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil) diff --git a/servers/https.go b/servers/https.go index 3888aae..47682b7 100644 --- a/servers/https.go +++ b/servers/https.go @@ -42,7 +42,6 @@ func NewHttpsServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server }) return &http.Server{ - Addr: conf.HttpsListen, Handler: hsts, TLSConfig: &tls.Config{ // Suggested by https://ssl-config.mozilla.org/#server=go&version=1.21.5&config=intermediate diff --git a/utils/server-utils.go b/utils/server-utils.go index 32656a2..7a75e57 100644 --- a/utils/server-utils.go +++ b/utils/server-utils.go @@ -3,6 +3,7 @@ package utils import ( "errors" "github.com/charmbracelet/log" + "net" "net/http" "strings" ) @@ -21,14 +22,14 @@ func logHttpServerError(logger *log.Logger, err error) { // RunBackgroundHttp runs a http server and logs when the server closes or // errors. -func RunBackgroundHttp(logger *log.Logger, s *http.Server) { - logHttpServerError(logger, s.ListenAndServe()) +func RunBackgroundHttp(logger *log.Logger, s *http.Server, ln net.Listener) { + logHttpServerError(logger, s.Serve(ln)) } // RunBackgroundHttps runs a http server with TLS encryption and logs when the // server closes or errors. -func RunBackgroundHttps(logger *log.Logger, s *http.Server) { - logHttpServerError(logger, s.ListenAndServeTLS("", "")) +func RunBackgroundHttps(logger *log.Logger, s *http.Server, ln net.Listener) { + logHttpServerError(logger, s.ServeTLS(ln, "", "")) } // GetBearer returns the bearer from the Authorization header or an empty string