diff --git a/certs/certs.go b/certs/certs.go index 999ca45..b79b3e5 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -11,6 +11,7 @@ import ( "sync" ) +// Certs is the certificate loader and management system. type Certs struct { cDir fs.FS kDir fs.FS @@ -18,6 +19,7 @@ type Certs struct { m map[string]*tls.Certificate } +// New creates a new cert list func New(certDir fs.FS, keyDir fs.FS) *Certs { a := &Certs{ cDir: certDir, @@ -25,6 +27,8 @@ func New(certDir fs.FS, keyDir fs.FS) *Certs { s: &sync.RWMutex{}, m: make(map[string]*tls.Certificate), } + + // run compile to get the initial data a.Compile() return a } @@ -62,6 +66,7 @@ func (c *Certs) Compile() { log.Printf("[Certs] Compile failed: %s\n", err) return } + // lock while replacing the map c.s.Lock() c.m = certMap @@ -69,6 +74,8 @@ func (c *Certs) Compile() { }() } +// internalCompile is a hidden internal method for loading the certificate and +// key files func (c *Certs) internalCompile(m map[string]*tls.Certificate) error { // try to read dir files, err := fs.ReadDir(c.cDir, "") diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 4094c82..1d51441 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -16,8 +16,9 @@ import ( "os" ) +// flags - each one has a usage field lol var ( - databasePath = flag.String("db", "", "/path/to/database.sqlite") + databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages") @@ -30,11 +31,12 @@ var ( func main() { log.Println("[Violet] Starting...") - // create paths + // create path to cert dir err := os.MkdirAll(*certPath, os.ModePerm) if err != nil { log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) } + // create path to key dir err = os.MkdirAll(*keyPath, os.ModePerm) if err != nil { log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) @@ -52,6 +54,7 @@ func main() { dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider + // struct containing config for the http servers srvConf := &servers.Conf{ ApiListen: *apiListen, HttpListen: *httpListen, diff --git a/domains/domains.go b/domains/domains.go index 3caebca..5ae003b 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -8,6 +8,7 @@ import ( "sync" ) +// Domains is the domain list and management system. type Domains struct { db *sql.DB s *sync.RWMutex @@ -64,12 +65,16 @@ func (d *Domains) IsValid(host string) bool { func (d *Domains) Compile() { // async compile magic go func() { + // new map domainMap := make(map[string]struct{}) + + // compile map and check errors err := d.internalCompile(domainMap) if err != nil { log.Printf("[Domains] Compile failed: %s\n", err) return } + // lock while replacing the map d.s.Lock() d.m = domainMap diff --git a/error-pages/error-pages.go b/error-pages/error-pages.go index 1175f04..2534e0d 100644 --- a/error-pages/error-pages.go +++ b/error-pages/error-pages.go @@ -20,29 +20,40 @@ type ErrorPages struct { dir fs.FS } +// New creates a new error pages generator func New(dir fs.FS) *ErrorPages { return &ErrorPages{ s: &sync.RWMutex{}, m: make(map[int]func(rw http.ResponseWriter)), + // generic error page writer generic: func(rw http.ResponseWriter, code int) { + // if status text is empty then the code is unknown a := http.StatusText(code) if a != "" { + // output in "xxx Error Text" format http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code) return } + // output the code and generic unknown message http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code) }, dir: dir, } } +// ServeError writes the error page for the given code to the response writer func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) { + // read lock for safety e.s.RLock() defer e.s.RUnlock() + + // use the custom error page if it exists if p, ok := e.m[code]; ok { p(rw) return } + + // otherwise use the generic error page e.generic(rw, code) } @@ -58,6 +69,7 @@ func (e *ErrorPages) Compile() { log.Printf("[Certs] Compile failed: %s\n", err) return } + // lock while replacing the map e.s.Lock() e.m = errorPageMap @@ -97,6 +109,8 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name) continue } + + // check if code is in range 100-599 if nameInt < 100 || nameInt >= 600 { log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name) continue @@ -108,6 +122,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err return fmt.Errorf("failed to read html file '%s': %w", name, err) } + // create a callback function to write the page m[nameInt] = func(rw http.ResponseWriter) { rw.Header().Set("Content-Type", "text/html; encoding=utf-8") rw.WriteHeader(nameInt) diff --git a/favicons/favicons.go b/favicons/favicons.go index aff791b..372c60f 100644 --- a/favicons/favicons.go +++ b/favicons/favicons.go @@ -18,6 +18,15 @@ import ( var ErrFaviconNotFound = errors.New("favicon not found") +// Favicons is a dynamic favicon generator which supports overwriting favicons +type Favicons struct { + db *sql.DB + cmd string + cLock *sync.RWMutex + faviconMap map[string]*FaviconList +} + +// New creates a new dynamic favicon generator func New(db *sql.DB, inkscapeCmd string) *Favicons { f := &Favicons{ db: db, @@ -29,7 +38,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons { // init favicons table _, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`) if err != nil { - log.Printf("[WARN] Failed to generate 'domains' table\n") + log.Printf("[WARN] Failed to generate 'favicons' table\n") return nil } @@ -38,22 +47,22 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons { return f } -type Favicons struct { - db *sql.DB - cmd string - cLock *sync.RWMutex - faviconMap map[string]*FaviconList -} - +// Compile downloads the list of favicon mappings from the database and loads +// them and the target favicons into memory for faster lookups func (f *Favicons) Compile() { + // async compile magic go func() { + // new map favicons := make(map[string]*FaviconList) + + // compile map and check errors err := f.internalCompile(favicons) if err != nil { // log compile errors log.Printf("[Favicons] Compile failed: %s\n", err) return } + // lock while replacing the map f.cLock.Lock() f.faviconMap = favicons @@ -61,22 +70,27 @@ func (f *Favicons) Compile() { }() } -func (f *Favicons) GetIcons(host string) (*FaviconList, bool) { +// GetIcons returns the favicon list for the provided host or nil if no +// icon is found or generated +func (f *Favicons) GetIcons(host string) *FaviconList { + // read lock for safety f.cLock.RLock() defer f.cLock.RUnlock() - if a, ok := f.faviconMap[host]; ok { - return a, true - } - return nil, false + + // return value from map + return f.faviconMap[host] } +// internalCompile is a hidden internal method for loading and generating all +// favicons. func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { // query all rows in database - query, err := f.db.Query(`select * from favicons`) + query, err := f.db.Query(`select host, svg, png, ico from favicons`) if err != nil { return fmt.Errorf("failed to prepare query: %w", err) } + // loop over rows and scan in data using error group to catch errors var g errgroup.Group for query.Next() { var host, rawSvg, rawPng, rawIco string @@ -85,12 +99,17 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { return fmt.Errorf("failed to scan row: %w", err) } + // create favicon list for this row l := &FaviconList{ Ico: CreateFaviconImage(rawIco), Png: CreateFaviconImage(rawPng), Svg: CreateFaviconImage(rawSvg), } + + // save the favicon list to the map faviconMap[host] = l + + // run the pre-process in a separate goroutine g.Go(func() error { return l.PreProcess(f.convertSvgToPng) }) @@ -98,16 +117,19 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { return g.Wait() } +// convertSvgToPng calls svg2png which runs inkscape in a subprocess func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) { return svg2png(f.cmd, in) } +// FaviconList contains the ico, png and svg icons for separate favicons type FaviconList struct { Ico *FaviconImage // can be generated from png with wrapper Png *FaviconImage // can be generated from svg with inkscape Svg *FaviconImage } +// ProduceIco outputs the bytes of the ico icon or an error func (l *FaviconList) ProduceIco() ([]byte, error) { if l.Ico == nil { return nil, ErrFaviconNotFound @@ -115,6 +137,7 @@ func (l *FaviconList) ProduceIco() ([]byte, error) { return l.Ico.Raw, nil } +// ProducePng outputs the bytes of the png icon or an error func (l *FaviconList) ProducePng() ([]byte, error) { if l.Png == nil { return nil, ErrFaviconNotFound @@ -122,6 +145,7 @@ func (l *FaviconList) ProducePng() ([]byte, error) { return l.Png.Raw, nil } +// ProduceSvg outputs the bytes of the svg icon or an error func (l *FaviconList) ProduceSvg() ([]byte, error) { if l.Svg == nil { return nil, ErrFaviconNotFound @@ -129,6 +153,8 @@ func (l *FaviconList) ProduceSvg() ([]byte, error) { return l.Svg.Raw, nil } +// PreProcess takes an input of the svg2png conversion function and outputs +// an error if the SVG, PNG or ICO fails to download or generate func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error { var err error @@ -178,10 +204,13 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err) } } + + // generate sha256 hashes for svg, png and ico l.genSha256() return nil } +// genSha256 generates sha256 hashes func (l *FaviconList) genSha256() { if l.Svg != nil { l.Svg.Hash = genSha256(l.Svg.Raw) @@ -194,6 +223,8 @@ func (l *FaviconList) genSha256() { } } +// getFaviconViaRequest uses the standard http request library to download +// icons, outputs the raw bytes from the download or an error. func getFaviconViaRequest(url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -210,21 +241,27 @@ func getFaviconViaRequest(url string) ([]byte, error) { return rawBody, nil } +// genSha256 generates a sha256 hash as a hex encoded string func genSha256(in []byte) string { + // create sha256 generator and write to it h := sha256.New() _, err := h.Write(in) if err != nil { return "" } + // encode as hex return hex.EncodeToString(h.Sum(nil)) } +// FaviconImage stores the url, hash and raw bytes of an image type FaviconImage struct { Url string Hash string Raw []byte } +// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if +// the URL is an empty string. func CreateFaviconImage(url string) *FaviconImage { if url == "" { return nil diff --git a/favicons/svg2png.go b/favicons/svg2png.go index 269322b..de2caf1 100644 --- a/favicons/svg2png.go +++ b/favicons/svg2png.go @@ -6,24 +6,31 @@ import ( "os/exec" ) +// svg2png takes an input inkscape binary path and svg image bytes and outputs +// the png image bytes or an error. func svg2png(inkscapeCmd string, in []byte) (out []byte, err error) { + // create stdout and stderr buffers var stdout, stderr bytes.Buffer + // prepare inkscape command and attach buffers cmd := exec.Command(inkscapeCmd, "--export-type", "png", "--export-filename", "-", "--export-background-opacity", "0", "--pipe") cmd.Stdin = bytes.NewBuffer(in) cmd.Stdout = &stdout cmd.Stderr = &stderr + // run the command and return errors if e := cmd.Run(); e != nil { err = fmt.Errorf("%s\nSTDERR:\n%s", e.Error(), stderr.String()) return } + // error if there is no output if stdout.Len() == 0 { err = fmt.Errorf("got no data from inkscape") return } + // return the raw bytes out = stdout.Bytes() return } diff --git a/servers/api.go b/servers/api.go index f26963b..292ebc2 100644 --- a/servers/api.go +++ b/servers/api.go @@ -46,7 +46,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server // Create and run http server s := &http.Server{ - Addr: listen, + Addr: conf.ApiListen, Handler: r, ReadTimeout: time.Minute, ReadHeaderTimeout: time.Minute, diff --git a/utils/domain-utils.go b/utils/domain-utils.go index 2762c53..fee486f 100644 --- a/utils/domain-utils.go +++ b/utils/domain-utils.go @@ -5,6 +5,8 @@ import ( "strings" ) +// SplitDomainPort takes an input host and default port then outputs the domain, +// port and true or empty values and false if the split failed func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) { a := strings.SplitN(host, ":", 2) switch len(a) { @@ -21,42 +23,63 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok return } +// GetDomainWithoutPort takes an input domain + port and outputs the domain +// without the port. +// +// example.com:443 => example.com func GetDomainWithoutPort(domain string) (string, bool) { - a := strings.SplitN(domain, ":", 2) - if len(a) == 2 { - return a[0], true - } - if len(a) == 0 { + // if a valid index isn't found then return false + n := strings.LastIndexByte(domain, ':') + if n == -1 { return "", false } - return a[0], true + return domain[:n], true } +// ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced +// with a wildcard '*' character. +// +// www.example.com => *.example.com func ReplaceSubdomainWithWildcard(domain string) (string, bool) { - a, b := GetBaseDomain(domain) - return "*." + a, b + // if a valid index isn't found then return false + n := strings.IndexByte(domain, '.') + if n == -1 { + return "", false + } + return "*" + domain[n:], true } -func GetBaseDomain(domain string) (string, bool) { - a := strings.SplitN(domain, ".", 2) - l := len(a) - if l == 2 { - return a[1], true +// GetParentDomain returns the parent domain stripping off the subdomain. +// +// www.example.com => example.com +func GetParentDomain(domain string) (string, bool) { + // if a valid index isn't found then return false + n := strings.IndexByte(domain, '.') + if n == -1 { + return "", false } - if l == 1 { - return a[0], true - } - return "", false + return domain[n+1:], true } +// GetTopFqdn returns the top domain stripping off multiple layers of subdomains. +// +// hello.world.example.com => example.com func GetTopFqdn(domain string) (string, bool) { - a := strings.Split(domain, ".") - l := len(a) - if l >= 2 { - return strings.Join(a[l-2:], "."), true + var countDot int + n := strings.LastIndexFunc(domain, func(r rune) bool { + // return true if this is the second '.' + // otherwise counts one and continues + if r == '.' { + if countDot == 1 { + return true + } + countDot++ + } + return false + }) + // if a valid index isn't found then return false + if n == -1 { + return "", false } - if l == 1 { - return domain, true - } - return "", false + return domain[n+1:], true } diff --git a/utils/domain-utils_test.go b/utils/domain-utils_test.go index 4eca7d9..46251e1 100644 --- a/utils/domain-utils_test.go +++ b/utils/domain-utils_test.go @@ -38,11 +38,11 @@ func TestReplaceSubdomainWithWildcard(t *testing.T) { } func TestGetBaseDomain(t *testing.T) { - domain, ok := GetBaseDomain("www.example.com") + domain, ok := GetParentDomain("www.example.com") assert.True(t, ok, "Output should be true") assert.Equal(t, "example.com", domain) - domain, ok = GetBaseDomain("www.example.com:5612") + domain, ok = GetParentDomain("www.example.com:5612") assert.True(t, ok, "Output should be true") assert.Equal(t, "example.com:5612", domain) } diff --git a/utils/fast-redirect.go b/utils/fast-redirect.go index e82784b..1e51b21 100644 --- a/utils/fast-redirect.go +++ b/utils/fast-redirect.go @@ -4,12 +4,8 @@ import ( "net/http" ) -var ( - a1 = []byte("") - a3 = []byte(".\n") -) - +// FastRedirect adds a location header, status code and if the method is get, +// outputs the status text. func FastRedirect(rw http.ResponseWriter, req *http.Request, url string, code int) { rw.Header().Add("Location", url) rw.WriteHeader(code) diff --git a/utils/multi-compilable.go b/utils/multi-compilable.go index ec76a8f..5c5ace9 100644 --- a/utils/multi-compilable.go +++ b/utils/multi-compilable.go @@ -1,11 +1,15 @@ package utils +// Compilable is an interface for structs with an asynchronous compile method. type Compilable interface { Compile() } +// MultiCompilable is a slice of multiple Compilable interfaces. type MultiCompilable []Compilable +// Compile loops over the slice of Compilable interfaces and calls Compile on +// each one. func (m MultiCompilable) Compile() { for _, i := range m { i.Compile() diff --git a/utils/server-utils.go b/utils/server-utils.go index 30f4ca6..e390693 100644 --- a/utils/server-utils.go +++ b/utils/server-utils.go @@ -24,6 +24,8 @@ func RunBackgroundHttps(prefix string, s *http.Server) { logHttpServerError(prefix, s.ListenAndServeTLS("", "")) } +// GetBearer returns the bearer from the Authorization header or an empty string +// if the authorization is empty or doesn't start with Bearer. func GetBearer(req *http.Request) string { a := req.Header.Get("Authorization") if t, ok := strings.CutPrefix(a, "Bearer "); ok {