From 9147a813cb8348862eebd0027c31560374f0b94b Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 24 Apr 2023 15:36:21 +0100 Subject: [PATCH] Add Ctrl+C handling, self-signed mode for devs and fix some bugs in routing --- .gitignore | 1 + .idea/dataSources.xml | 4 +- certs/certs.go | 52 ++++++- cmd/violet/main.go | 72 +++++++--- domains/create-table-domains.sql | 6 + domains/domains.go | 14 +- identifier.sqlite | Bin 16384 -> 0 bytes router/create-table-redirects.sql | 10 ++ router/create-table-routes.sql | 14 ++ router/manager.go | 228 ++++++++++++++++++++++++++++++ router/query-table-redirects.sql | 7 + router/query-table-routes.sql | 11 ++ router/router.go | 7 +- router/router_test.go | 2 - servers/conf.go | 11 +- servers/https.go | 16 +-- target/redirect.go | 44 ++++-- target/route.go | 73 ++++++++-- target/route_test.go | 75 ++++++++++ utils/response.go | 1 + utils/server-utils.go | 6 + 21 files changed, 581 insertions(+), 73 deletions(-) create mode 100644 .gitignore create mode 100644 domains/create-table-domains.sql delete mode 100644 identifier.sqlite create mode 100644 router/create-table-redirects.sql create mode 100644 router/create-table-routes.sql create mode 100644 router/manager.go create mode 100644 router/query-table-redirects.sql create mode 100644 router/query-table-routes.sql create mode 100644 target/route_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9b1dffd --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.sqlite diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml index 2e275d2..aba4cb2 100644 --- a/.idea/dataSources.xml +++ b/.idea/dataSources.xml @@ -1,11 +1,11 @@ - + sqlite.xerial true org.sqlite.JDBC - jdbc:sqlite:identifier.sqlite + jdbc:sqlite:__db.sqlite $ProjectFileDir$ diff --git a/certs/certs.go b/certs/certs.go index b79b3e5..66a0f6d 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -3,34 +3,55 @@ package certs import ( "code.mrmelon54.com/melon/certgen" "crypto/tls" + "crypto/x509/pkix" "fmt" "github.com/MrMelon54/violet/utils" "io/fs" "log" + "math/big" "path/filepath" "sync" + "sync/atomic" + "time" ) // Certs is the certificate loader and management system. type Certs struct { cDir fs.FS kDir fs.FS + ss bool s *sync.RWMutex m map[string]*tls.Certificate + ca *certgen.CertGen + sn atomic.Int64 } // New creates a new cert list -func New(certDir fs.FS, keyDir fs.FS) *Certs { - a := &Certs{ +func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs { + c := &Certs{ cDir: certDir, kDir: keyDir, + ss: selfCert, s: &sync.RWMutex{}, m: make(map[string]*tls.Certificate), } + if c.ss { + ca, err := certgen.MakeCaTls(pkix.Name{ + Country: []string{"GB"}, + Organization: []string{"Violet"}, + OrganizationalUnit: []string{"Development"}, + SerialNumber: "0", + CommonName: fmt.Sprintf("%d.violet.test", time.Now().Unix()), + }, big.NewInt(0)) + if err != nil { + log.Fatalln("Failed to generate CA cert for self-signed mode:", err) + } + c.ca = ca + } // run compile to get the initial data - a.Compile() - return a + c.Compile() + return c } func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { @@ -43,6 +64,24 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { return cert } + // if self-signed certificate is enabled then generate a certificate + if c.ss { + sn := c.sn.Add(1) + serverTls, err := certgen.MakeServerTls(c.ca, pkix.Name{ + Country: []string{"GB"}, + Organization: []string{domain}, + OrganizationalUnit: []string{domain}, + SerialNumber: fmt.Sprintf("%d", sn), + CommonName: domain, + }, big.NewInt(sn), []string{domain}, nil) + if err != nil { + return nil + } + leaf := serverTls.GetTlsLeaf() + c.m[domain] = &leaf + return &leaf + } + // lookup and return wildcard cert if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok { if cert, ok := c.m[wildcardDomain]; ok { @@ -55,6 +94,11 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { } func (c *Certs) Compile() { + // don't bother compiling in self-signed mode + if c.ss { + return + } + // async compile magic go func() { // new map diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 1d51441..e7bf603 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -4,16 +4,22 @@ import ( "database/sql" _ "embed" "flag" + "fmt" "github.com/MrMelon54/violet/certs" "github.com/MrMelon54/violet/domains" errorPages "github.com/MrMelon54/violet/error-pages" "github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/proxy" + "github.com/MrMelon54/violet/router" "github.com/MrMelon54/violet/servers" "github.com/MrMelon54/violet/utils" _ "github.com/mattn/go-sqlite3" "log" + "net/http" "os" + "os/signal" + "syscall" + "time" ) // flags - each one has a usage field lol @@ -21,6 +27,7 @@ var ( databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") + selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode") errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages") apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") @@ -30,16 +37,21 @@ var ( func main() { log.Println("[Violet] Starting...") + flag.Parse() - // create path to cert dir - err := os.MkdirAll(*certPath, os.ModePerm) - if err != nil { - log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + if *certPath != "" { + // create path to cert dir + err := os.MkdirAll(*certPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + } } - // create path to key dir - err = os.MkdirAll(*keyPath, os.ModePerm) - if err != nil { - log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) + if *keyPath != "" { + // create path to key dir + err := os.MkdirAll(*keyPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) + } } // open sqlite database @@ -48,11 +60,12 @@ func main() { log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath) } - allowedDomains := domains.New(db) // load allowed domains - allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath)) // load certificate manager - reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy - dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider - dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider + allowedDomains := domains.New(db) // load allowed domains + allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager + reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy + dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider + dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider + dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager // struct containing config for the http servers srvConf := &servers.Conf{ @@ -65,16 +78,41 @@ func main() { Favicons: dynamicFavicons, Verify: nil, // TODO: add mjwt verify support ErrorPages: dynamicErrorPages, - Proxy: reverseProxy, + Router: dynamicRouter, } + var srvApi, srvHttp, srvHttps *http.Server if *apiListen != "" { - servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains}) + srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}) } if *httpListen != "" { - servers.NewHttpServer(srvConf) + srvHttp = servers.NewHttpServer(srvConf) } if *httpsListen != "" { - servers.NewHttpsServer(srvConf) + srvHttps = servers.NewHttpsServer(srvConf) } + + // Wait for exit signal + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill) + <-sc + fmt.Println() + + // Stop servers + log.Printf("[Violet] Stopping...") + n := time.Now() + + // close http servers + if srvApi != nil { + srvApi.Close() + } + if srvHttp != nil { + srvHttp.Close() + } + if srvHttps != nil { + srvHttps.Close() + } + + log.Printf("[Violet] Took '%s' to shutdown\n", time.Now().Sub(n)) + log.Println("[Violet] Goodbye") } diff --git a/domains/create-table-domains.sql b/domains/create-table-domains.sql new file mode 100644 index 0000000..129e224 --- /dev/null +++ b/domains/create-table-domains.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS domains +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT, + active INTEGER DEFAULT 1 +); diff --git a/domains/domains.go b/domains/domains.go index 5ae003b..9bac6e6 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -2,12 +2,16 @@ package domains import ( "database/sql" + _ "embed" "github.com/MrMelon54/violet/utils" "log" "strings" "sync" ) +//go:embed create-table-domains.sql +var createTableDomains string + // Domains is the domain list and management system. type Domains struct { db *sql.DB @@ -24,7 +28,7 @@ func New(db *sql.DB) *Domains { } // init domains table - _, err := a.db.Exec(`create table if not exists domains (id integer primary key autoincrement, domain varchar)`) + _, err := a.db.Exec(createTableDomains) if err != nil { log.Printf("[WARN] Failed to generate 'domains' table\n") return nil @@ -37,11 +41,7 @@ func New(db *sql.DB) *Domains { // IsValid returns true if a domain is valid. func (d *Domains) IsValid(host string) bool { - // remove the port - domain, ok := utils.GetDomainWithoutPort(host) - if !ok { - return false - } + domain, _, _ := utils.SplitDomainPort(host, 0) // read lock for safety d.s.RLock() @@ -88,7 +88,7 @@ func (d *Domains) internalCompile(m map[string]struct{}) error { log.Println("[Domains] Updating domains from database") // sql or something? - rows, err := d.db.Query("select name from domains where enabled = true") + rows, err := d.db.Query(`select domain from domains where active = 1`) if err != nil { return err } diff --git a/identifier.sqlite b/identifier.sqlite deleted file mode 100644 index 288d146498ac3cc3dba539263fc609851a3193e2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeI&&uYRj90%~kmU${^x3&sc`6;o|f^O8L8$26b# zp|6m>N2~6RShXACUf14ChXMfzKmY;|fB*y_009U<00PG^;M?U!vq`t7nm^CQBn-Gz zan_Zt?uUc%kc|hIqan)%SzAt+R9Z|$#1@eZc(h?JV#D}Khf;Zw2!ztTY&l!=$eZz~ zv#nX>Myo})VVcP2Yw3k5&U^bHRX+OLT(dCNd8QtV*VFG{q4olkGN0+(mYhbnOCQti z#LKz4R6oUHxf03~h0r?KTw$`Ucp!S_&5k(?enNTj*OW;H=Gabp+Nw!gC&YbsKhgm- z5P$##AOHafKmY;|fB*y_009UbyMXIh{300Izz y00bZa0SG_<0uU$&RB876g_Ga^3kwiJ00Izz00bZa0SG_<0uX=z1R(H_1U>