Add support for tableflip

This commit is contained in:
Melon 2024-08-17 12:29:50 +01:00
parent f442409ebf
commit 3e86b91ec3
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
9 changed files with 127 additions and 83 deletions

View File

@ -18,35 +18,53 @@ import (
"github.com/1f349/violet/servers/api"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/charmbracelet/log"
"github.com/cloudflare/tableflip"
"github.com/google/subcommands"
"github.com/mrmelon54/exit-reload"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"io/fs"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
)
type serveCmd struct {
configPath string
cpuprofile string
debugLog bool
pidFile string
}
func (s *serveCmd) Name() string { return "serve" }
func (s *serveCmd) Synopsis() string { return "Serve reverse proxy server" }
func (s *serveCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.configPath, "conf", "", "/path/to/config.json : path to the config file")
f.BoolVar(&s.debugLog, "debug", false, "enable debug logging")
f.StringVar(&s.pidFile, "pid-file", "", "path to pid file")
}
func (s *serveCmd) Usage() string {
return `serve [-conf <config file>]
return `serve [-conf <config file>] [-debug] [-pid-file <pid file>]
Serve reverse proxy server using information from config file
`
}
func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
if s.debugLog {
logger.Logger.SetLevel(log.DebugLevel)
}
logger.Logger.Info("Starting...")
upg, err := tableflip.New(tableflip.Options{
PIDFile: s.pidFile,
})
if err != nil {
panic(err)
}
defer upg.Stop()
if s.configPath == "" {
logger.Logger.Info("Error: config flag is missing")
return subcommands.ExitUsageError
@ -71,13 +89,9 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
// working directory is the parent of the config file
wd := filepath.Dir(s.configPath)
normalLoad(config, wd)
return subcommands.ExitSuccess
}
func normalLoad(startUp startUpConfig, wd string) {
// the cert and key paths are useless in self-signed mode
if !startUp.SelfSigned {
if !config.SelfSigned {
// create path to cert dir
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
if err != nil {
@ -92,11 +106,11 @@ func normalLoad(startUp startUpConfig, wd string) {
// errorPageDir stores an FS interface for accessing the error page directory
var errorPageDir fs.FS
if startUp.ErrorPagePath != "" {
errorPageDir = os.DirFS(startUp.ErrorPagePath)
err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm)
if config.ErrorPagePath != "" {
errorPageDir = os.DirFS(config.ErrorPagePath)
err := os.MkdirAll(config.ErrorPagePath, os.ModePerm)
if err != nil {
logger.Logger.Fatal("Failed to create error page", "path", startUp.ErrorPagePath)
logger.Logger.Fatal("Failed to create error page", "path", config.ErrorPagePath)
}
}
@ -125,18 +139,15 @@ func normalLoad(startUp startUpConfig, wd string) {
ws := websocket.NewServer()
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
allowedCerts := certs.New(certDir, keyDir, config.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
dynamicFavicons := favicons.New(db, config.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
// struct containing config for the http servers
srvConf := &conf.Conf{
ApiListen: startUp.Listen.Api,
HttpListen: startUp.Listen.Http,
HttpsListen: startUp.Listen.Https,
RateLimit: startUp.RateLimit,
RateLimit: config.RateLimit,
DB: db,
Domains: allowedDomains,
Acme: acmeChallenges,
@ -151,32 +162,72 @@ func normalLoad(startUp startUpConfig, wd string) {
allCompilables := utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}
allCompilables.Compile()
_, httpsPort, ok := utils.SplitDomainPort(config.Listen.Https, 443)
if !ok {
httpsPort = 443
}
var srvApi, srvHttp, srvHttps *http.Server
if srvConf.ApiListen != "" {
if config.Listen.Api != "" {
// Listen must be called before Ready
lnApi, err := upg.Listen("tcp", config.Listen.Api)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvApi = api.NewApiServer(srvConf, allCompilables, promRegistry)
srvApi.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "API")
l.Info("Starting server", "addr", srvApi.Addr)
go utils.RunBackgroundHttp(l, srvApi)
go utils.RunBackgroundHttp(l, srvApi, lnApi)
}
if srvConf.HttpListen != "" {
srvHttp = servers.NewHttpServer(srvConf, promRegistry)
if config.Listen.Http != "" {
// Listen must be called before Ready
lnHttp, err := upg.Listen("tcp", config.Listen.Http)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvHttp = servers.NewHttpServer(uint16(httpsPort), srvConf, promRegistry)
srvHttp.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "HTTP")
l.Info("Starting server", "addr", srvHttp.Addr)
go utils.RunBackgroundHttp(l, srvHttp)
go utils.RunBackgroundHttp(l, srvHttp, lnHttp)
}
if config.Listen.Https != "" {
// Listen must be called before Ready
lnHttps, err := upg.Listen("tcp", config.Listen.Https)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
if srvConf.HttpsListen != "" {
srvHttps = servers.NewHttpsServer(srvConf, promRegistry)
srvHttps.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "HTTPS")
l.Info("Starting server", "addr", srvHttps.Addr)
go utils.RunBackgroundHttps(l, srvHttps)
go utils.RunBackgroundHttps(l, srvHttps, lnHttps)
}
exit_reload.ExitReload("Violet", func() {
allCompilables.Compile()
}, func() {
// Do an upgrade on SIGHUP
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGHUP)
for range sig {
err := upg.Upgrade()
if err != nil {
logger.Logger.Error("Failed upgrade", "err", err)
}
}
}()
logger.Logger.Info("Ready")
if err := upg.Ready(); err != nil {
panic(err)
}
<-upg.Exit()
time.AfterFunc(30*time.Second, func() {
logger.Logger.Warn("Graceful shutdown timed out")
os.Exit(1)
})
// stop updating certificates
allowedCerts.Stop()
@ -193,5 +244,6 @@ func normalLoad(startUp startUpConfig, wd string) {
if srvHttps != nil {
_ = srvHttps.Close()
}
})
return subcommands.ExitSuccess
}

2
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/1f349/mjwt v0.2.5
github.com/AlecAivazis/survey/v2 v2.3.7
github.com/charmbracelet/log v0.4.0
github.com/cloudflare/tableflip v1.2.3
github.com/golang-migrate/migrate/v4 v4.17.1
github.com/google/subcommands v1.2.0
github.com/google/uuid v1.6.0
@ -13,7 +14,6 @@ require (
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/mrmelon54/certgen v0.0.2
github.com/mrmelon54/exit-reload v0.0.2
github.com/mrmelon54/png2ico v1.0.2
github.com/mrmelon54/rescheduler v0.0.3
github.com/mrmelon54/trie v0.0.3

5
go.sum
View File

@ -16,6 +16,8 @@ github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMt
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/charmbracelet/log v0.4.0 h1:G9bQAcx8rWA2T3pWvx7YtPTPwgqpk7D68BX21IRW8ZM=
github.com/charmbracelet/log v0.4.0/go.mod h1:63bXt/djrizTec0l11H20t8FDSvA4CRZJ1KH22MdptM=
github.com/cloudflare/tableflip v1.2.3 h1:8I+B99QnnEWPHOY3fWipwVKxS70LGgUsslG7CSfmHMw=
github.com/cloudflare/tableflip v1.2.3/go.mod h1:P4gRehmV6Z2bY5ao5ml9Pd8u6kuEnlB37pUFMmv7j2E=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI=
github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
@ -72,8 +74,6 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQ
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/mrmelon54/certgen v0.0.2 h1:4CMDkA/gGZu+E4iikU+5qdOWK7qOQrk58KtUfnmyYmY=
github.com/mrmelon54/certgen v0.0.2/go.mod h1:vwrWSXQmxZYqEyh+cf05IvDIFV2aYuxL4+O6ABIlN8M=
github.com/mrmelon54/exit-reload v0.0.2 h1:vqgfrMD/bF21HkDsWgg5+NLjFDrD3KGVEN/iTrMn9Ms=
github.com/mrmelon54/exit-reload v0.0.2/go.mod h1:aE3NhsqGMLUqmv6cJZRouC/8gXkZTvVSabRGOpI+Vjc=
github.com/mrmelon54/png2ico v1.0.2 h1:KyJd3ATmDjxAJS28MTSf44GxzYnlZ+7KT8SXzGb3sN8=
github.com/mrmelon54/png2ico v1.0.2/go.mod h1:vp8Be9y5cz102ANon+BnsIzTUdet3VQRvOuWJTH9h0M=
github.com/mrmelon54/rescheduler v0.0.3 h1:TrkJL6S7PKvXuo1mvdgRgsILA/pk5L1lrXhV/q7IEzQ=
@ -130,6 +130,7 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -48,7 +48,6 @@ func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable, registry
// Create and run http server
return &http.Server{
Addr: conf.ApiListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,

View File

@ -11,9 +11,6 @@ import (
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
type Conf struct {
ApiListen string // api server listen address
HttpListen string // http server listen address
HttpsListen string // https server listen address
RateLimit uint64 // rate limit per minute
DB *database.Queries
Domains utils.DomainProvider

View File

@ -17,13 +17,9 @@ import (
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Let's Encrypt HTTP verification.
func NewHttpServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server {
func NewHttpServer(httpsPort uint16, conf *conf.Conf, registry *prometheus.Registry) *http.Server {
r := httprouter.New()
var secureExtend string
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
if !ok {
httpsPort = 443
}
if httpsPort != 443 {
secureExtend = fmt.Sprintf(":%d", httpsPort)
}
@ -72,7 +68,6 @@ func NewHttpServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server
// Create and run http server
return &http.Server{
Addr: conf.HttpListen,
Handler: metricsMiddleware,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,

View File

@ -18,7 +18,7 @@ func TestNewHttpServer_AcmeChallenge(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewHttpServer(httpConf, nil)
srv := NewHttpServer(443, httpConf, nil)
httpConf.Acme.Put("example.com", "456", "456def")
req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil)

View File

@ -42,7 +42,6 @@ func NewHttpsServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server
})
return &http.Server{
Addr: conf.HttpsListen,
Handler: hsts,
TLSConfig: &tls.Config{
// Suggested by https://ssl-config.mozilla.org/#server=go&version=1.21.5&config=intermediate

View File

@ -3,6 +3,7 @@ package utils
import (
"errors"
"github.com/charmbracelet/log"
"net"
"net/http"
"strings"
)
@ -21,14 +22,14 @@ func logHttpServerError(logger *log.Logger, err error) {
// RunBackgroundHttp runs a http server and logs when the server closes or
// errors.
func RunBackgroundHttp(logger *log.Logger, s *http.Server) {
logHttpServerError(logger, s.ListenAndServe())
func RunBackgroundHttp(logger *log.Logger, s *http.Server, ln net.Listener) {
logHttpServerError(logger, s.Serve(ln))
}
// RunBackgroundHttps runs a http server with TLS encryption and logs when the
// server closes or errors.
func RunBackgroundHttps(logger *log.Logger, s *http.Server) {
logHttpServerError(logger, s.ListenAndServeTLS("", ""))
func RunBackgroundHttps(logger *log.Logger, s *http.Server, ln net.Listener) {
logHttpServerError(logger, s.ServeTLS(ln, "", ""))
}
// GetBearer returns the bearer from the Authorization header or an empty string