Rework compilable system with the rescheduler library

This commit is contained in:
Melon 2023-06-20 16:48:04 +01:00
parent e9db9d6ef2
commit 629057edc3
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
16 changed files with 202 additions and 104 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
*.sqlite
*.local
violet

View File

@ -5,6 +5,7 @@ import (
"crypto/x509/pkix"
"fmt"
"github.com/MrMelon54/certgen"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/utils"
"io/fs"
"log"
@ -24,6 +25,7 @@ type Certs struct {
m map[string]*tls.Certificate
ca *certgen.CertGen
sn atomic.Int64
r *rescheduler.Rescheduler
}
// New creates a new cert list
@ -35,6 +37,9 @@ func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate),
}
c.r = rescheduler.NewRescheduler(c.threadCompile)
// in self-signed mode generate a CA certificate to sign other certificates
if c.ss {
ca, err := certgen.MakeCaTls(4096, pkix.Name{
Country: []string{"GB"},
@ -81,6 +86,8 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
if err != nil {
return nil
}
// save the generated leaf for loading if the domain is requested again
leaf := serverTls.GetTlsLeaf()
c.m[domain] = &leaf
return &leaf
@ -97,29 +104,33 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
return nil
}
// Compile loads the certificates and keys from the directories.
//
// This method makes use of the rescheduler instead of just ignoring multiple
// calls.
func (c *Certs) Compile() {
// don't bother compiling in self-signed mode
if c.ss {
return
}
c.r.Run()
}
// async compile magic
go func() {
// new map
certMap := make(map[string]*tls.Certificate)
func (c *Certs) threadCompile() {
// new map
certMap := make(map[string]*tls.Certificate)
// compile map and check errors
err := c.internalCompile(certMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
// compile map and check errors
err := c.internalCompile(certMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
// lock while replacing the map
c.s.Lock()
c.m = certMap
c.s.Unlock()
}()
// lock while replacing the map
c.s.Lock()
c.m = certMap
c.s.Unlock()
}
// internalCompile is a hidden internal method for loading the certificate and

View File

@ -5,6 +5,7 @@ import (
_ "embed"
"flag"
"fmt"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/domains"
errorPages "github.com/MrMelon54/violet/error-pages"
@ -25,6 +26,7 @@ import (
// flags - each one has a usage field lol
var (
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
mjwtPubKey = flag.String("mjwt", "", "/path/to/mjwt-public-key.pem : path to the pem encoded rsa public key file")
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode")
@ -33,26 +35,35 @@ var (
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
inkscapeCmd = flag.String("inkscape", "inkscape", "Path to inkscape binary")
rateLimit = flag.Uint64("ratelimit", 300, "Rate limit (max requests per minute)")
rateLimit = flag.Uint64("rate-limit", 300, "Rate limit (max requests per minute)")
)
func main() {
log.Println("[Violet] Starting...")
flag.Parse()
if *certPath != "" {
// create path to cert dir
err := os.MkdirAll(*certPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
// the cert and key paths are useless in self-signed mode
if !*selfSigned {
if *certPath != "" {
// create path to cert dir
err := os.MkdirAll(*certPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
}
}
if *keyPath != "" {
// create path to key dir
err := os.MkdirAll(*keyPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
}
}
}
if *keyPath != "" {
// create path to key dir
err := os.MkdirAll(*keyPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
}
// load the MJWT RSA public key from a pem encoded file
mjwtVerify, err := mjwt.NewMJwtVerifierFromFile(*mjwtPubKey)
if err != nil {
log.Fatalf("[Violet] Failed to load MJWT verifier public key from file: '%s'", *mjwtPubKey)
}
// open sqlite database
@ -80,7 +91,7 @@ func main() {
Acme: acmeChallenges,
Certs: allowedCerts,
Favicons: dynamicFavicons,
Verify: nil, // TODO: add mjwt verify support
Verify: mjwtVerify,
ErrorPages: dynamicErrorPages,
Router: dynamicRouter,
}

View File

@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS domains
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain TEXT,
domain TEXT UNIQUE,
active INTEGER DEFAULT 1
);

View File

@ -3,6 +3,7 @@ package domains
import (
"database/sql"
_ "embed"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/utils"
"log"
"strings"
@ -17,6 +18,7 @@ type Domains struct {
db *sql.DB
s *sync.RWMutex
m map[string]struct{}
r *rescheduler.Rescheduler
}
// New creates a new domain list
@ -26,6 +28,7 @@ func New(db *sql.DB) *Domains {
s: &sync.RWMutex{},
m: make(map[string]struct{}),
}
a.r = rescheduler.NewRescheduler(a.threadCompile)
// init domains table
_, err := a.db.Exec(createTableDomains)
@ -64,25 +67,27 @@ func (d *Domains) IsValid(host string) bool {
// Compile downloads the list of domains from the database and loads them into
// memory for faster lookups.
//
// This method is asynchronous and uses locks for safety.
// This method makes use of the rescheduler instead of just ignoring multiple
// calls.
func (d *Domains) Compile() {
// async compile magic
go func() {
// new map
domainMap := make(map[string]struct{})
d.r.Run()
}
// compile map and check errors
err := d.internalCompile(domainMap)
if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err)
return
}
func (d *Domains) threadCompile() {
// new map
domainMap := make(map[string]struct{})
// lock while replacing the map
d.s.Lock()
d.m = domainMap
d.s.Unlock()
}()
// compile map and check errors
err := d.internalCompile(domainMap)
if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err)
return
}
// lock while replacing the map
d.s.Lock()
d.m = domainMap
d.s.Unlock()
}
// internalCompile is a hidden internal method for querying the database during
@ -110,3 +115,21 @@ func (d *Domains) internalCompile(m map[string]struct{}) error {
// check for errors
return rows.Err()
}
func (d *Domains) Put(domain string, active bool) {
d.s.Lock()
defer d.s.Unlock()
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, active)
if err != nil {
log.Printf("[Violet] Database error: %s\n", err)
}
}
func (d *Domains) Delete(domain string) {
d.s.Lock()
defer d.s.Unlock()
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, false)
if err != nil {
log.Printf("[Violet] Database error: %s\n", err)
}
}

View File

@ -12,7 +12,7 @@ func TestDomainsNew(t *testing.T) {
assert.NoError(t, err)
domains := New(db)
_, err = db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
_, err = db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1)
assert.NoError(t, err)
domains.Compile()
@ -31,7 +31,7 @@ func TestDomains_IsValid(t *testing.T) {
assert.NoError(t, err)
domains := New(db)
_, err = domains.db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
_, err = domains.db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1)
assert.NoError(t, err)
domains.s.Lock()

View File

@ -2,6 +2,7 @@ package error_pages
import (
"fmt"
"github.com/MrMelon54/rescheduler"
"io/fs"
"log"
"net/http"
@ -18,11 +19,12 @@ type ErrorPages struct {
m map[int]func(rw http.ResponseWriter)
generic func(rw http.ResponseWriter, code int)
dir fs.FS
r *rescheduler.Rescheduler
}
// New creates a new error pages generator
func New(dir fs.FS) *ErrorPages {
return &ErrorPages{
e := &ErrorPages{
s: &sync.RWMutex{},
m: make(map[int]func(rw http.ResponseWriter)),
// generic error page writer
@ -40,6 +42,8 @@ func New(dir fs.FS) *ErrorPages {
},
dir: dir,
}
e.r = rescheduler.NewRescheduler(e.threadCompile)
return e
}
// ServeError writes the error page for the given code to the response writer
@ -58,26 +62,31 @@ func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
e.generic(rw, code)
}
// Compile loads the error pages the certificates and keys from the directories.
//
// This method makes use of the rescheduler instead of just ignoring multiple
// calls.
func (e *ErrorPages) Compile() {
// async compile magic
go func() {
// new map
errorPageMap := make(map[int]func(rw http.ResponseWriter))
e.r.Run()
}
// compile map and check errors
if e.dir != nil {
err := e.internalCompile(errorPageMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
func (e *ErrorPages) threadCompile() {
// new map
errorPageMap := make(map[int]func(rw http.ResponseWriter))
// compile map and check errors
if e.dir != nil {
err := e.internalCompile(errorPageMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
}
// lock while replacing the map
e.s.Lock()
e.m = errorPageMap
e.s.Unlock()
}()
// lock while replacing the map
e.s.Lock()
e.m = errorPageMap
e.s.Unlock()
}
func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error {

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/MrMelon54/rescheduler"
"golang.org/x/sync/errgroup"
"log"
"sync"
@ -17,6 +18,7 @@ type Favicons struct {
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
r *rescheduler.Rescheduler
}
// New creates a new dynamic favicon generator
@ -27,6 +29,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
cLock: &sync.RWMutex{},
faviconMap: make(map[string]*FaviconList),
}
f.r = rescheduler.NewRescheduler(f.threadCompile)
// init favicons table
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
@ -40,29 +43,6 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
return f
}
// Compile downloads the list of favicon mappings from the database and loads
// them and the target favicons into memory for faster lookups
func (f *Favicons) Compile() {
// async compile magic
go func() {
// new map
favicons := make(map[string]*FaviconList)
// compile map and check errors
err := f.internalCompile(favicons)
if err != nil {
// log compile errors
log.Printf("[Favicons] Compile failed: %s\n", err)
return
}
// lock while replacing the map
f.cLock.Lock()
f.faviconMap = favicons
f.cLock.Unlock()
}()
}
// GetIcons returns the favicon list for the provided host or nil if no
// icon is found or generated
func (f *Favicons) GetIcons(host string) *FaviconList {
@ -74,9 +54,36 @@ func (f *Favicons) GetIcons(host string) *FaviconList {
return f.faviconMap[host]
}
// Compile downloads the list of favicon mappings from the database and loads
// them and the target favicons into memory for faster lookups
//
// This method makes use of the rescheduler instead of just ignoring multiple
// calls.
func (f *Favicons) Compile() {
f.r.Run()
}
func (f *Favicons) threadCompile() {
// new map
favicons := make(map[string]*FaviconList)
// compile map and check errors
err := f.internalCompile(favicons)
if err != nil {
// log compile errors
log.Printf("[Favicons] Compile failed: %s\n", err)
return
}
// lock while replacing the map
f.cLock.Lock()
f.faviconMap = favicons
f.cLock.Unlock()
}
// internalCompile is a hidden internal method for loading and generating all
// favicons.
func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
func (f *Favicons) internalCompile(m map[string]*FaviconList) error {
// query all rows in database
query, err := f.db.Query(`select host, svg, png, ico from favicons`)
if err != nil {
@ -100,7 +107,7 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
}
// save the favicon list to the map
faviconMap[host] = l
m[host] = l
// run the pre-process in a separate goroutine
g.Go(func() error {

1
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/MrMelon54/certgen v0.0.1
github.com/MrMelon54/mjwt v0.0.2
github.com/MrMelon54/png2ico v1.0.1
github.com/MrMelon54/rescheduler v0.0.1
github.com/MrMelon54/trie v0.0.2
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16

2
go.sum
View File

@ -4,6 +4,8 @@ github.com/MrMelon54/mjwt v0.0.2 h1:jDqyPnFloh80XdSmZ6jt9qhUj/ULcoQ4QSHXPdkAIE4=
github.com/MrMelon54/mjwt v0.0.2/go.mod h1:HzY8P6Je+ovS/fwK5sILRMq5mnZT4+WuFRc98LBy7z4=
github.com/MrMelon54/png2ico v1.0.1 h1:zJoSSl4OkvSIMWGyGPvb8fWNa0KrUvMIjgNGLNLJhVQ=
github.com/MrMelon54/png2ico v1.0.1/go.mod h1:NOv3tO4497mInG+3tcFkIohmxCywUwMLU8WNxJZLVmU=
github.com/MrMelon54/rescheduler v0.0.1 h1:gzNvL8X81M00uYN0i9clFVrXCkG1UuLNYxDcvjKyBqo=
github.com/MrMelon54/rescheduler v0.0.1/go.mod h1:OQDFtZHdS4/qA/r7rtJUQA22/hbpnZ9MGQCXOPjhC6w=
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

View File

@ -4,6 +4,7 @@ import (
"database/sql"
_ "embed"
"fmt"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
@ -20,6 +21,7 @@ type Manager struct {
s *sync.RWMutex
r *Router
p *proxy.HybridTransport
z *rescheduler.Rescheduler
}
var (
@ -42,6 +44,7 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
r: New(proxy),
p: proxy,
}
m.z = rescheduler.NewRescheduler(m.threadCompile)
// init routes table
_, err := m.db.Exec(createTableRoutes)
@ -69,22 +72,24 @@ func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
func (m *Manager) Compile() {
go func() {
// new router
router := New(m.p)
m.z.Run()
}
// compile router and check errors
err := m.internalCompile(router)
if err != nil {
log.Printf("[Manager] Compile failed: %s\n", err)
return
}
func (m *Manager) threadCompile() {
// new router
router := New(m.p)
// lock while replacing router
m.s.Lock()
m.r = router
m.s.Unlock()
}()
// compile router and check errors
err := m.internalCompile(router)
if err != nil {
log.Printf("[Manager] Compile failed: %s\n", err)
return
}
// lock while replacing router
m.s.Lock()
m.r = router
m.s.Unlock()
}
// internalCompile is a hidden internal method for querying the database during

View File

@ -28,6 +28,28 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
rw.WriteHeader(http.StatusAccepted)
})
// Endpoint for domains
r.PUT("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Verify, req, "violet:domains") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// add domain with active state
q := req.URL.Query()
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
})
r.DELETE("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Verify, req, "violet:domains") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// add domain with active state
q := req.URL.Query()
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
})
// Endpoint for acme-challenge
r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Verify, req, "violet:acme-challenge") {

View File

@ -18,7 +18,9 @@ var snakeOilProv = genSnakeOilProv()
type fakeDomains struct{}
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
func (f *fakeDomains) Put(domain string, active bool) {}
func (f *fakeDomains) Delete(domain string) {}
func genSnakeOilProv() mjwt.Signer {
key, err := rsa.GenerateKey(rand.Reader, 1024)

View File

@ -27,6 +27,8 @@ type Conf struct {
type DomainProvider interface {
IsValid(host string) bool
Put(domain string, active bool)
Delete(domain string)
}
type AcmeChallengeProvider interface {
@ -37,4 +39,5 @@ type AcmeChallengeProvider interface {
type CertProvider interface {
GetCertForDomain(domain string) *tls.Certificate
Compile()
}