diff --git a/.gitignore b/.gitignore index 9b1dffd..857e074 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ *.sqlite +*.local +violet diff --git a/certs/certs.go b/certs/certs.go index dc801c9..f9abc83 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -5,6 +5,7 @@ import ( "crypto/x509/pkix" "fmt" "github.com/MrMelon54/certgen" + "github.com/MrMelon54/rescheduler" "github.com/MrMelon54/violet/utils" "io/fs" "log" @@ -24,6 +25,7 @@ type Certs struct { m map[string]*tls.Certificate ca *certgen.CertGen sn atomic.Int64 + r *rescheduler.Rescheduler } // New creates a new cert list @@ -35,6 +37,9 @@ func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs { s: &sync.RWMutex{}, m: make(map[string]*tls.Certificate), } + c.r = rescheduler.NewRescheduler(c.threadCompile) + + // in self-signed mode generate a CA certificate to sign other certificates if c.ss { ca, err := certgen.MakeCaTls(4096, pkix.Name{ Country: []string{"GB"}, @@ -81,6 +86,8 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { if err != nil { return nil } + + // save the generated leaf for loading if the domain is requested again leaf := serverTls.GetTlsLeaf() c.m[domain] = &leaf return &leaf @@ -97,29 +104,33 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { return nil } +// Compile loads the certificates and keys from the directories. +// +// This method makes use of the rescheduler instead of just ignoring multiple +// calls. func (c *Certs) Compile() { // don't bother compiling in self-signed mode if c.ss { return } + c.r.Run() +} - // async compile magic - go func() { - // new map - certMap := make(map[string]*tls.Certificate) +func (c *Certs) threadCompile() { + // new map + certMap := make(map[string]*tls.Certificate) - // compile map and check errors - err := c.internalCompile(certMap) - if err != nil { - log.Printf("[Certs] Compile failed: %s\n", err) - return - } + // compile map and check errors + err := c.internalCompile(certMap) + if err != nil { + log.Printf("[Certs] Compile failed: %s\n", err) + return + } - // lock while replacing the map - c.s.Lock() - c.m = certMap - c.s.Unlock() - }() + // lock while replacing the map + c.s.Lock() + c.m = certMap + c.s.Unlock() } // internalCompile is a hidden internal method for loading the certificate and diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 97ea2e8..3e29d92 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -5,6 +5,7 @@ import ( _ "embed" "flag" "fmt" + "github.com/MrMelon54/mjwt" "github.com/MrMelon54/violet/certs" "github.com/MrMelon54/violet/domains" errorPages "github.com/MrMelon54/violet/error-pages" @@ -25,6 +26,7 @@ import ( // flags - each one has a usage field lol var ( databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file") + mjwtPubKey = flag.String("mjwt", "", "/path/to/mjwt-public-key.pem : path to the pem encoded rsa public key file") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode") @@ -33,26 +35,35 @@ var ( httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") inkscapeCmd = flag.String("inkscape", "inkscape", "Path to inkscape binary") - rateLimit = flag.Uint64("ratelimit", 300, "Rate limit (max requests per minute)") + rateLimit = flag.Uint64("rate-limit", 300, "Rate limit (max requests per minute)") ) func main() { log.Println("[Violet] Starting...") flag.Parse() - if *certPath != "" { - // create path to cert dir - err := os.MkdirAll(*certPath, os.ModePerm) - if err != nil { - log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + // the cert and key paths are useless in self-signed mode + if !*selfSigned { + if *certPath != "" { + // create path to cert dir + err := os.MkdirAll(*certPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + } + } + if *keyPath != "" { + // create path to key dir + err := os.MkdirAll(*keyPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) + } } } - if *keyPath != "" { - // create path to key dir - err := os.MkdirAll(*keyPath, os.ModePerm) - if err != nil { - log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) - } + + // load the MJWT RSA public key from a pem encoded file + mjwtVerify, err := mjwt.NewMJwtVerifierFromFile(*mjwtPubKey) + if err != nil { + log.Fatalf("[Violet] Failed to load MJWT verifier public key from file: '%s'", *mjwtPubKey) } // open sqlite database @@ -80,7 +91,7 @@ func main() { Acme: acmeChallenges, Certs: allowedCerts, Favicons: dynamicFavicons, - Verify: nil, // TODO: add mjwt verify support + Verify: mjwtVerify, ErrorPages: dynamicErrorPages, Router: dynamicRouter, } diff --git a/domains/create-table-domains.sql b/domains/create-table-domains.sql index 129e224..279adcf 100644 --- a/domains/create-table-domains.sql +++ b/domains/create-table-domains.sql @@ -1,6 +1,6 @@ CREATE TABLE IF NOT EXISTS domains ( id INTEGER PRIMARY KEY AUTOINCREMENT, - domain TEXT, + domain TEXT UNIQUE, active INTEGER DEFAULT 1 ); diff --git a/domains/domains.go b/domains/domains.go index 343a057..9380e6c 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -3,6 +3,7 @@ package domains import ( "database/sql" _ "embed" + "github.com/MrMelon54/rescheduler" "github.com/MrMelon54/violet/utils" "log" "strings" @@ -17,6 +18,7 @@ type Domains struct { db *sql.DB s *sync.RWMutex m map[string]struct{} + r *rescheduler.Rescheduler } // New creates a new domain list @@ -26,6 +28,7 @@ func New(db *sql.DB) *Domains { s: &sync.RWMutex{}, m: make(map[string]struct{}), } + a.r = rescheduler.NewRescheduler(a.threadCompile) // init domains table _, err := a.db.Exec(createTableDomains) @@ -64,25 +67,27 @@ func (d *Domains) IsValid(host string) bool { // Compile downloads the list of domains from the database and loads them into // memory for faster lookups. // -// This method is asynchronous and uses locks for safety. +// This method makes use of the rescheduler instead of just ignoring multiple +// calls. func (d *Domains) Compile() { - // async compile magic - go func() { - // new map - domainMap := make(map[string]struct{}) + d.r.Run() +} - // compile map and check errors - err := d.internalCompile(domainMap) - if err != nil { - log.Printf("[Domains] Compile failed: %s\n", err) - return - } +func (d *Domains) threadCompile() { + // new map + domainMap := make(map[string]struct{}) - // lock while replacing the map - d.s.Lock() - d.m = domainMap - d.s.Unlock() - }() + // compile map and check errors + err := d.internalCompile(domainMap) + if err != nil { + log.Printf("[Domains] Compile failed: %s\n", err) + return + } + + // lock while replacing the map + d.s.Lock() + d.m = domainMap + d.s.Unlock() } // internalCompile is a hidden internal method for querying the database during @@ -110,3 +115,21 @@ func (d *Domains) internalCompile(m map[string]struct{}) error { // check for errors return rows.Err() } + +func (d *Domains) Put(domain string, active bool) { + d.s.Lock() + defer d.s.Unlock() + _, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, active) + if err != nil { + log.Printf("[Violet] Database error: %s\n", err) + } +} + +func (d *Domains) Delete(domain string) { + d.s.Lock() + defer d.s.Unlock() + _, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, false) + if err != nil { + log.Printf("[Violet] Database error: %s\n", err) + } +} diff --git a/domains/domains_test.go b/domains/domains_test.go index a4a68b9..d2a0e3f 100644 --- a/domains/domains_test.go +++ b/domains/domains_test.go @@ -12,7 +12,7 @@ func TestDomainsNew(t *testing.T) { assert.NoError(t, err) domains := New(db) - _, err = db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1) + _, err = db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1) assert.NoError(t, err) domains.Compile() @@ -31,7 +31,7 @@ func TestDomains_IsValid(t *testing.T) { assert.NoError(t, err) domains := New(db) - _, err = domains.db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1) + _, err = domains.db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1) assert.NoError(t, err) domains.s.Lock() diff --git a/error-pages/error-pages.go b/error-pages/error-pages.go index 38f316d..86819fb 100644 --- a/error-pages/error-pages.go +++ b/error-pages/error-pages.go @@ -2,6 +2,7 @@ package error_pages import ( "fmt" + "github.com/MrMelon54/rescheduler" "io/fs" "log" "net/http" @@ -18,11 +19,12 @@ type ErrorPages struct { m map[int]func(rw http.ResponseWriter) generic func(rw http.ResponseWriter, code int) dir fs.FS + r *rescheduler.Rescheduler } // New creates a new error pages generator func New(dir fs.FS) *ErrorPages { - return &ErrorPages{ + e := &ErrorPages{ s: &sync.RWMutex{}, m: make(map[int]func(rw http.ResponseWriter)), // generic error page writer @@ -40,6 +42,8 @@ func New(dir fs.FS) *ErrorPages { }, dir: dir, } + e.r = rescheduler.NewRescheduler(e.threadCompile) + return e } // ServeError writes the error page for the given code to the response writer @@ -58,26 +62,31 @@ func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) { e.generic(rw, code) } +// Compile loads the error pages the certificates and keys from the directories. +// +// This method makes use of the rescheduler instead of just ignoring multiple +// calls. func (e *ErrorPages) Compile() { - // async compile magic - go func() { - // new map - errorPageMap := make(map[int]func(rw http.ResponseWriter)) + e.r.Run() +} - // compile map and check errors - if e.dir != nil { - err := e.internalCompile(errorPageMap) - if err != nil { - log.Printf("[Certs] Compile failed: %s\n", err) - return - } +func (e *ErrorPages) threadCompile() { + // new map + errorPageMap := make(map[int]func(rw http.ResponseWriter)) + + // compile map and check errors + if e.dir != nil { + err := e.internalCompile(errorPageMap) + if err != nil { + log.Printf("[Certs] Compile failed: %s\n", err) + return } + } - // lock while replacing the map - e.s.Lock() - e.m = errorPageMap - e.s.Unlock() - }() + // lock while replacing the map + e.s.Lock() + e.m = errorPageMap + e.s.Unlock() } func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error { diff --git a/favicons/favicons.go b/favicons/favicons.go index e3682b8..bbf4995 100644 --- a/favicons/favicons.go +++ b/favicons/favicons.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/MrMelon54/rescheduler" "golang.org/x/sync/errgroup" "log" "sync" @@ -17,6 +18,7 @@ type Favicons struct { cmd string cLock *sync.RWMutex faviconMap map[string]*FaviconList + r *rescheduler.Rescheduler } // New creates a new dynamic favicon generator @@ -27,6 +29,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons { cLock: &sync.RWMutex{}, faviconMap: make(map[string]*FaviconList), } + f.r = rescheduler.NewRescheduler(f.threadCompile) // init favicons table _, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`) @@ -40,29 +43,6 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons { return f } -// Compile downloads the list of favicon mappings from the database and loads -// them and the target favicons into memory for faster lookups -func (f *Favicons) Compile() { - // async compile magic - go func() { - // new map - favicons := make(map[string]*FaviconList) - - // compile map and check errors - err := f.internalCompile(favicons) - if err != nil { - // log compile errors - log.Printf("[Favicons] Compile failed: %s\n", err) - return - } - - // lock while replacing the map - f.cLock.Lock() - f.faviconMap = favicons - f.cLock.Unlock() - }() -} - // GetIcons returns the favicon list for the provided host or nil if no // icon is found or generated func (f *Favicons) GetIcons(host string) *FaviconList { @@ -74,9 +54,36 @@ func (f *Favicons) GetIcons(host string) *FaviconList { return f.faviconMap[host] } +// Compile downloads the list of favicon mappings from the database and loads +// them and the target favicons into memory for faster lookups +// +// This method makes use of the rescheduler instead of just ignoring multiple +// calls. +func (f *Favicons) Compile() { + f.r.Run() +} + +func (f *Favicons) threadCompile() { + // new map + favicons := make(map[string]*FaviconList) + + // compile map and check errors + err := f.internalCompile(favicons) + if err != nil { + // log compile errors + log.Printf("[Favicons] Compile failed: %s\n", err) + return + } + + // lock while replacing the map + f.cLock.Lock() + f.faviconMap = favicons + f.cLock.Unlock() +} + // internalCompile is a hidden internal method for loading and generating all // favicons. -func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { +func (f *Favicons) internalCompile(m map[string]*FaviconList) error { // query all rows in database query, err := f.db.Query(`select host, svg, png, ico from favicons`) if err != nil { @@ -100,7 +107,7 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { } // save the favicon list to the map - faviconMap[host] = l + m[host] = l // run the pre-process in a separate goroutine g.Go(func() error { diff --git a/go.mod b/go.mod index f5a8609..7947ba5 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/MrMelon54/certgen v0.0.1 github.com/MrMelon54/mjwt v0.0.2 github.com/MrMelon54/png2ico v1.0.1 + github.com/MrMelon54/rescheduler v0.0.1 github.com/MrMelon54/trie v0.0.2 github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 22efb14..be21a2d 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/MrMelon54/mjwt v0.0.2 h1:jDqyPnFloh80XdSmZ6jt9qhUj/ULcoQ4QSHXPdkAIE4= github.com/MrMelon54/mjwt v0.0.2/go.mod h1:HzY8P6Je+ovS/fwK5sILRMq5mnZT4+WuFRc98LBy7z4= github.com/MrMelon54/png2ico v1.0.1 h1:zJoSSl4OkvSIMWGyGPvb8fWNa0KrUvMIjgNGLNLJhVQ= github.com/MrMelon54/png2ico v1.0.1/go.mod h1:NOv3tO4497mInG+3tcFkIohmxCywUwMLU8WNxJZLVmU= +github.com/MrMelon54/rescheduler v0.0.1 h1:gzNvL8X81M00uYN0i9clFVrXCkG1UuLNYxDcvjKyBqo= +github.com/MrMelon54/rescheduler v0.0.1/go.mod h1:OQDFtZHdS4/qA/r7rtJUQA22/hbpnZ9MGQCXOPjhC6w= github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY= github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/router/manager.go b/router/manager.go index 988d215..447621e 100644 --- a/router/manager.go +++ b/router/manager.go @@ -4,6 +4,7 @@ import ( "database/sql" _ "embed" "fmt" + "github.com/MrMelon54/rescheduler" "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/target" "github.com/MrMelon54/violet/utils" @@ -20,6 +21,7 @@ type Manager struct { s *sync.RWMutex r *Router p *proxy.HybridTransport + z *rescheduler.Rescheduler } var ( @@ -42,6 +44,7 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager { r: New(proxy), p: proxy, } + m.z = rescheduler.NewRescheduler(m.threadCompile) // init routes table _, err := m.db.Exec(createTableRoutes) @@ -69,22 +72,24 @@ func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } func (m *Manager) Compile() { - go func() { - // new router - router := New(m.p) + m.z.Run() +} - // compile router and check errors - err := m.internalCompile(router) - if err != nil { - log.Printf("[Manager] Compile failed: %s\n", err) - return - } +func (m *Manager) threadCompile() { + // new router + router := New(m.p) - // lock while replacing router - m.s.Lock() - m.r = router - m.s.Unlock() - }() + // compile router and check errors + err := m.internalCompile(router) + if err != nil { + log.Printf("[Manager] Compile failed: %s\n", err) + return + } + + // lock while replacing router + m.s.Lock() + m.r = router + m.s.Unlock() } // internalCompile is a hidden internal method for querying the database during diff --git a/servers/api.go b/servers/api.go index cfd3f2b..60a3fe3 100644 --- a/servers/api.go +++ b/servers/api.go @@ -28,6 +28,28 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server rw.WriteHeader(http.StatusAccepted) }) + // Endpoint for domains + r.PUT("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if !hasPerms(conf.Verify, req, "violet:domains") { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + + // add domain with active state + q := req.URL.Query() + conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1") + }) + r.DELETE("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if !hasPerms(conf.Verify, req, "violet:domains") { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + + // add domain with active state + q := req.URL.Query() + conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1") + }) + // Endpoint for acme-challenge r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { if !hasPerms(conf.Verify, req, "violet:acme-challenge") { diff --git a/servers/api_test.go b/servers/api_test.go index d465998..58dd10c 100644 --- a/servers/api_test.go +++ b/servers/api_test.go @@ -18,7 +18,9 @@ var snakeOilProv = genSnakeOilProv() type fakeDomains struct{} -func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" } +func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" } +func (f *fakeDomains) Put(domain string, active bool) {} +func (f *fakeDomains) Delete(domain string) {} func genSnakeOilProv() mjwt.Signer { key, err := rsa.GenerateKey(rand.Reader, 1024) diff --git a/servers/conf.go b/servers/conf.go index ab3e877..f6f738d 100644 --- a/servers/conf.go +++ b/servers/conf.go @@ -27,6 +27,8 @@ type Conf struct { type DomainProvider interface { IsValid(host string) bool + Put(domain string, active bool) + Delete(domain string) } type AcmeChallengeProvider interface { @@ -37,4 +39,5 @@ type AcmeChallengeProvider interface { type CertProvider interface { GetCertForDomain(domain string) *tls.Certificate + Compile() } diff --git a/utils/multi-compilable.go b/utils/compilable.go similarity index 100% rename from utils/multi-compilable.go rename to utils/compilable.go diff --git a/utils/multi-compilable_test.go b/utils/compilable_test.go similarity index 100% rename from utils/multi-compilable_test.go rename to utils/compilable_test.go