From 371a150c2a9c0ef0049683e337e9dfd91dc16ded Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Sat, 22 Apr 2023 18:11:21 +0100 Subject: [PATCH] Server config and comments --- .idea/sqldialects.xml | 1 - certs/certs.go | 120 +++++++++++++++++++++++++++++++++++++ cmd/violet/init.sql | 1 - cmd/violet/main.go | 70 ++++++++++++++-------- domains/domains.go | 13 +++- error-pages/error-pages.go | 46 ++++++++++++++ favicons/favicons.go | 3 + go.mod | 3 + go.sum | 6 ++ servers/conf.go | 24 ++++++++ servers/http.go | 14 +++-- servers/https.go | 78 ++++++++++++++++++++++++ utils/domain-utils.go | 7 ++- utils/domain-utils_test.go | 4 +- 14 files changed, 351 insertions(+), 39 deletions(-) create mode 100644 certs/certs.go delete mode 100644 cmd/violet/init.sql create mode 100644 error-pages/error-pages.go create mode 100644 favicons/favicons.go create mode 100644 servers/conf.go diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml index ec81629..c0e01ca 100644 --- a/.idea/sqldialects.xml +++ b/.idea/sqldialects.xml @@ -1,7 +1,6 @@ - \ No newline at end of file diff --git a/certs/certs.go b/certs/certs.go new file mode 100644 index 0000000..098e8e3 --- /dev/null +++ b/certs/certs.go @@ -0,0 +1,120 @@ +package certs + +import ( + "code.mrmelon54.com/melon/certgen" + "crypto/tls" + "fmt" + "github.com/MrMelon54/violet/utils" + "io/fs" + "log" + "path/filepath" + "sync" +) + +type Certs struct { + cDir fs.FS + kDir fs.FS + s *sync.RWMutex + m map[string]*tls.Certificate +} + +func New(certDir fs.FS, keyDir fs.FS) *Certs { + a := &Certs{ + cDir: certDir, + kDir: keyDir, + s: &sync.RWMutex{}, + m: make(map[string]*tls.Certificate), + } + a.Compile() + return a +} + +func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { + // safety read lock + c.s.RLock() + defer c.s.RUnlock() + + // lookup and return cert + if cert, ok := c.m[domain]; ok { + return cert + } + + // lookup and return wildcard cert + if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok { + if cert, ok := c.m[wildcardDomain]; ok { + return cert + } + } + + // no cert found + return nil +} + +func (c *Certs) Compile() { + // async compile magic + go func() { + certMap := make(map[string]*tls.Certificate) + err := c.internalCompile(certMap) + if err != nil { + log.Printf("[Certs] Compile failed: %s\n", err) + return + } + // lock while replacing the map + c.s.Lock() + c.m = certMap + c.s.Unlock() + }() +} + +func (c *Certs) internalCompile(m map[string]*tls.Certificate) error { + // try to read dir + files, err := fs.ReadDir(c.cDir, "") + if err != nil { + return fmt.Errorf("failed to read cert dir: %w", err) + } + + log.Printf("[Certs] Compiling lookup table for %d certificates\n", len(files)) + + // find and parse certs + for _, i := range files { + // skip dirs + if i.IsDir() { + continue + } + + // get file name and extension + name := i.Name() + ext := filepath.Ext(name) + keyName := name[:len(name)-len(ext)] + "key" + + // try to read cert file + certData, err := fs.ReadFile(c.cDir, name) + if err != nil { + return fmt.Errorf("failed to read cert file '%s': %w", name, err) + } + + // try to read key file + keyData, err := fs.ReadFile(c.kDir, keyName) + if err != nil { + return fmt.Errorf("failed to read key file '%s': %w", keyName, err) + } + + // load key pair + pair, err := tls.X509KeyPair(certData, keyData) + if err != nil { + return fmt.Errorf("failed to load x509 key pair '%s + %s': %w", name, keyName, err) + } + + // load tls leaf + cert := &pair + leaf := certgen.TlsLeaf(cert) + + // save in map under each dns name + for _, j := range leaf.DNSNames { + m[j] = cert + } + } + + // well no errors happened + return nil +} diff --git a/cmd/violet/init.sql b/cmd/violet/init.sql deleted file mode 100644 index 4cc4215..0000000 --- a/cmd/violet/init.sql +++ /dev/null @@ -1 +0,0 @@ -create table acme_challenge (id integer not null primary key, key varchar, value varchar); diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 89fc271..b230249 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -3,8 +3,8 @@ package main import ( "database/sql" _ "embed" - "errors" "flag" + "github.com/MrMelon54/violet/certs" "github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/router" @@ -15,45 +15,65 @@ import ( "os" ) -//go:embed init.sql -var initSql string - var ( - databasePath = flag.String("db", "", "/path/to/database.sqlite") - certPath = flag.String("cert", "", "/path/to/certificates") - apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") - httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") - httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") + databasePath = flag.String("db", "", "/path/to/database.sqlite") + keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") + certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") + errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages") + apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") + httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") + httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") ) func main() { log.Println("[Violet] Starting...") - _, err := os.Stat(*certPath) - if errors.Is(err, os.ErrNotExist) { - log.Fatalf("[Violet] Certificate path '%s' does not exists", *certPath) + // create paths + err := os.MkdirAll(*certPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + } + err = os.MkdirAll(*keyPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) } - _, err = os.Stat(*databasePath) - dbExists := !errors.Is(err, os.ErrNotExist) - + // open sqlite database db, err := sql.Open("sqlite3", *databasePath) if err != nil { log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath) } - if !dbExists { - log.Println("[Violet] Creating new database and running init.sql") - _, err = db.Exec(initSql) - if err != nil { - log.Fatalf("[Violet] Failed to run init.sql") - } - } - + // load allowed domains allowedDomains := domains.New(db) + + // load allowed certificates + allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath)) + + // create reverse proxy and reverseProxy := proxy.CreateHybridReverseProxy() r := router.New(reverseProxy) - servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains}) - servers.NewHttpServer(*httpListen, 0, allowedDomains, db) + srvConf := &servers.Conf{ + ApiListen: *apiListen, + HttpListen: *httpListen, + HttpsListen: *httpsListen, + DB: db, + Domains: allowedDomains, + Certs: allowedCerts, + Favicons: dynamicFavicons, + Verify: apiVerify, + ErrorPages: dynamicErrorPages, + Proxy: reverseProxy, + } + + if *apiListen != "" { + servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains}) + } + if *httpListen != "" { + servers.NewHttpServer(srvConf) + } + if *httpsListen != "" { + servers.NewHttpsServer(srvConf) + } } diff --git a/domains/domains.go b/domains/domains.go index 52dbd7f..46cc42e 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -14,14 +14,18 @@ type Domains struct { m map[string]struct{} } +// New creates a new domain list func New(db *sql.DB) *Domains { - return &Domains{ + a := &Domains{ db: db, s: &sync.RWMutex{}, m: make(map[string]struct{}), } + a.Compile() + return a } +// IsValid returns true if a domain is valid. func (d *Domains) IsValid(host string) bool { // remove the port domain, ok := utils.GetDomainWithoutPort(host) @@ -34,6 +38,7 @@ func (d *Domains) IsValid(host string) bool { defer d.s.RUnlock() // check root domains `www.example.com`, `example.com`, `com` + // TODO: could be faster using indexes and cropping the string? n := strings.Split(domain, ".") for i := 0; i < len(n); i++ { if _, ok := d.m[strings.Join(n[i:], ".")]; ok { @@ -43,6 +48,10 @@ func (d *Domains) IsValid(host string) bool { return false } +// Compile downloads the list of domains from the database and loads them into +// memory for faster lookups. +// +// This method is asynchronous and uses locks for safety. func (d *Domains) Compile() { // async compile magic go func() { @@ -59,6 +68,8 @@ func (d *Domains) Compile() { }() } +// internalCompile is a hidden internal method for querying the database during +// the Compile() method. func (d *Domains) internalCompile(m map[string]struct{}) error { log.Println("[Domains] Updating domains from database") diff --git a/error-pages/error-pages.go b/error-pages/error-pages.go new file mode 100644 index 0000000..3ec6468 --- /dev/null +++ b/error-pages/error-pages.go @@ -0,0 +1,46 @@ +package error_pages + +import ( + "fmt" + "net/http" + "sync" +) + +// ErrorPages stores the custom error pages and is called by the servers to +// output meaningful pages for HTTP error codes +type ErrorPages struct { + s *sync.RWMutex + m map[int]func(rw http.ResponseWriter) + generic func(rw http.ResponseWriter, code int) + dir string +} + +func New(dir string) *ErrorPages { + return &ErrorPages{ + s: &sync.RWMutex{}, + m: make(map[int]func(rw http.ResponseWriter)), + generic: func(rw http.ResponseWriter, code int) { + a := http.StatusText(code) + if a != "" { + http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code) + return + } + http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code) + }, + dir: dir, + } +} + +func (e *ErrorPages) Compile() { + +} + +func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) { + e.s.RLock() + defer e.s.RUnlock() + if p, ok := e.m[code]; ok { + p(rw) + return + } + e.generic(rw, code) +} diff --git a/favicons/favicons.go b/favicons/favicons.go new file mode 100644 index 0000000..ac582b2 --- /dev/null +++ b/favicons/favicons.go @@ -0,0 +1,3 @@ +package favicons + +type Favicons struct{} diff --git a/go.mod b/go.mod index 8472289..643b8c9 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,15 @@ module github.com/MrMelon54/violet go 1.20 require ( + code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1 code.mrmelon54.com/melon/summer-utils v0.0.3 github.com/MrMelon54/trie v0.0.2 + github.com/gorilla/mux v1.8.0 github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.16 github.com/mrmelon54/mjwt v0.0.1 github.com/rs/cors v1.9.0 + github.com/sethvargo/go-limiter v0.7.2 github.com/stretchr/testify v1.8.2 golang.org/x/net v0.9.0 ) diff --git a/go.sum b/go.sum index c9a78f3..452fb92 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1 h1:tll8DwvO1CL+xXJIMLyDmQYoYr/gA4BkcUFtNHB1BFo= +code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1/go.mod h1:Liyhe1bkNyeVfw6LicCgrQ+4oUT/w/qONLjvejkUim0= code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q= code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw= github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY= @@ -7,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= @@ -19,6 +23,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8= +github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/servers/conf.go b/servers/conf.go new file mode 100644 index 0000000..6479c86 --- /dev/null +++ b/servers/conf.go @@ -0,0 +1,24 @@ +package servers + +import ( + "database/sql" + "github.com/MrMelon54/violet/certs" + "github.com/MrMelon54/violet/domains" + errorPages "github.com/MrMelon54/violet/error-pages" + "github.com/MrMelon54/violet/favicons" + "github.com/mrmelon54/mjwt" + "net/http/httputil" +) + +type Conf struct { + ApiListen string + HttpListen string + HttpsListen string + DB *sql.DB + Domains *domains.Domains + Certs *certs.Certs + Favicons *favicons.Favicons + Verify mjwt.Provider + ErrorPages *errorPages.ErrorPages + Proxy *httputil.ReverseProxy +} diff --git a/servers/http.go b/servers/http.go index b2b99bd..7f04479 100644 --- a/servers/http.go +++ b/servers/http.go @@ -1,9 +1,7 @@ package servers import ( - "database/sql" "fmt" - "github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/utils" "github.com/julienschmidt/httprouter" "log" @@ -17,9 +15,13 @@ import ( // // `/.well-known/acme-challenge/{token}` is used for outputting answers for // acme challenges, this is used for Lets Encrypt HTTP verification. -func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, db *sql.DB) *http.Server { +func NewHttpServer(conf *Conf) *http.Server { r := httprouter.New() var secureExtend string + _, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443) + if !ok { + httpsPort = 443 + } if httpsPort != 443 { secureExtend = fmt.Sprintf(":%d", httpsPort) } @@ -28,7 +30,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { if h, ok := utils.GetDomainWithoutPort(req.Host); ok { // check if the host is valid - if !domainCheck.IsValid(req.Host) { + if !conf.Domains.IsValid(req.Host) { http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) return } @@ -41,7 +43,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d } // prepare for executing query - prepare, err := db.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?") + prepare, err := conf.DB.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?") if err != nil { utils.RespondHttpStatus(rw, http.StatusInternalServerError) return @@ -79,7 +81,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d // Create and run http server s := &http.Server{ - Addr: listen, + Addr: conf.HttpListen, Handler: r, ReadTimeout: time.Minute, ReadHeaderTimeout: time.Minute, diff --git a/servers/https.go b/servers/https.go index 84c4cc0..6b48e7d 100644 --- a/servers/https.go +++ b/servers/https.go @@ -1 +1,79 @@ package servers + +import ( + "crypto/tls" + "fmt" + "github.com/MrMelon54/violet/router" + "github.com/MrMelon54/violet/utils" + "github.com/gorilla/mux" + "github.com/sethvargo/go-limiter/httplimit" + "github.com/sethvargo/go-limiter/memorystore" + "log" + "net" + "net/http" + "time" +) + +// NewHttpsServer creates and runs a http server containing the public https +// endpoints for the reverse proxy. +func NewHttpsServer(conf *Conf) *http.Server { + r := router.New(conf.Proxy) + + s := &http.Server{ + Addr: conf.HttpsListen, + Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusNotImplemented) + _, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req))) + _ = r + // TODO: serve from router and proxy + // r.ServeHTTP(rw, req) + })), + DisableGeneralOptionsHandler: false, + TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // error out on invalid domains + if !conf.Domains.IsValid(info.ServerName) { + return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName) + } + + // find a certificate + cert := conf.Certs.GetCertForDomain(info.ServerName) + if cert == nil { + return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName) + } + + // time to return + return cert, nil + }}, + ReadTimeout: 150 * time.Second, + ReadHeaderTimeout: 150 * time.Second, + WriteTimeout: 150 * time.Second, + IdleTimeout: 150 * time.Second, + MaxHeaderBytes: 4096000, + ConnState: func(conn net.Conn, state http.ConnState) { + fmt.Printf("%s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String()) + }, + } + log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", s.Addr) + go utils.RunBackgroundHttps("HTTPS", s) + return s +} + +// setupRateLimiter is an internal function to create a middleware to manage +// rate limits. +func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc { + // create memory store + store, err := memorystore.New(&memorystore.Config{ + Tokens: rateLimit, + Interval: time.Minute, + }) + if err != nil { + log.Fatalln(err) + } + + // create a middleware using ips as the key for rate limits + middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc()) + if err != nil { + log.Fatalln(err) + } + return middleware.Handle +} diff --git a/utils/domain-utils.go b/utils/domain-utils.go index 43d7d08..2762c53 100644 --- a/utils/domain-utils.go +++ b/utils/domain-utils.go @@ -1,16 +1,17 @@ package utils import ( - "fmt" + "strconv" "strings" ) -func SplitDomainPort(host string, defaultPort uint16) (domain string, port uint16, ok bool) { +func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) { a := strings.SplitN(host, ":", 2) switch len(a) { case 2: domain = a[0] - _, err := fmt.Sscanf(a[1], "%d", &port) + p, err := strconv.Atoi(a[1]) + port = p ok = err == nil case 1: domain = a[0] diff --git a/utils/domain-utils_test.go b/utils/domain-utils_test.go index a73609f..4eca7d9 100644 --- a/utils/domain-utils_test.go +++ b/utils/domain-utils_test.go @@ -9,12 +9,12 @@ func TestSplitDomainPort(t *testing.T) { domain, port, ok := SplitDomainPort("www.example.com:5612", 443) assert.True(t, ok, "Output should be true") assert.Equal(t, "www.example.com", domain) - assert.Equal(t, uint16(5612), port) + assert.Equal(t, int(5612), port) domain, port, ok = SplitDomainPort("example.com", 443) assert.True(t, ok, "Output should be true") assert.Equal(t, "example.com", domain) - assert.Equal(t, uint16(443), port) + assert.Equal(t, int(443), port) } func TestDomainWithoutPort(t *testing.T) {