Comments!

This commit is contained in:
Melon 2023-04-24 01:35:23 +01:00
parent 71f8aaaf16
commit 0551e15979
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
12 changed files with 149 additions and 50 deletions

View File

@ -11,6 +11,7 @@ import (
"sync"
)
// Certs is the certificate loader and management system.
type Certs struct {
cDir fs.FS
kDir fs.FS
@ -18,6 +19,7 @@ type Certs struct {
m map[string]*tls.Certificate
}
// New creates a new cert list
func New(certDir fs.FS, keyDir fs.FS) *Certs {
a := &Certs{
cDir: certDir,
@ -25,6 +27,8 @@ func New(certDir fs.FS, keyDir fs.FS) *Certs {
s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate),
}
// run compile to get the initial data
a.Compile()
return a
}
@ -62,6 +66,7 @@ func (c *Certs) Compile() {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
// lock while replacing the map
c.s.Lock()
c.m = certMap
@ -69,6 +74,8 @@ func (c *Certs) Compile() {
}()
}
// internalCompile is a hidden internal method for loading the certificate and
// key files
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
// try to read dir
files, err := fs.ReadDir(c.cDir, "")

View File

@ -16,8 +16,9 @@ import (
"os"
)
// flags - each one has a usage field lol
var (
databasePath = flag.String("db", "", "/path/to/database.sqlite")
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
@ -30,11 +31,12 @@ var (
func main() {
log.Println("[Violet] Starting...")
// create paths
// create path to cert dir
err := os.MkdirAll(*certPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
}
// create path to key dir
err = os.MkdirAll(*keyPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
@ -52,6 +54,7 @@ func main() {
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
// struct containing config for the http servers
srvConf := &servers.Conf{
ApiListen: *apiListen,
HttpListen: *httpListen,

View File

@ -8,6 +8,7 @@ import (
"sync"
)
// Domains is the domain list and management system.
type Domains struct {
db *sql.DB
s *sync.RWMutex
@ -64,12 +65,16 @@ func (d *Domains) IsValid(host string) bool {
func (d *Domains) Compile() {
// async compile magic
go func() {
// new map
domainMap := make(map[string]struct{})
// compile map and check errors
err := d.internalCompile(domainMap)
if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err)
return
}
// lock while replacing the map
d.s.Lock()
d.m = domainMap

View File

@ -20,29 +20,40 @@ type ErrorPages struct {
dir fs.FS
}
// New creates a new error pages generator
func New(dir fs.FS) *ErrorPages {
return &ErrorPages{
s: &sync.RWMutex{},
m: make(map[int]func(rw http.ResponseWriter)),
// generic error page writer
generic: func(rw http.ResponseWriter, code int) {
// if status text is empty then the code is unknown
a := http.StatusText(code)
if a != "" {
// output in "xxx Error Text" format
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
return
}
// output the code and generic unknown message
http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code)
},
dir: dir,
}
}
// ServeError writes the error page for the given code to the response writer
func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
// read lock for safety
e.s.RLock()
defer e.s.RUnlock()
// use the custom error page if it exists
if p, ok := e.m[code]; ok {
p(rw)
return
}
// otherwise use the generic error page
e.generic(rw, code)
}
@ -58,6 +69,7 @@ func (e *ErrorPages) Compile() {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
// lock while replacing the map
e.s.Lock()
e.m = errorPageMap
@ -97,6 +109,8 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name)
continue
}
// check if code is in range 100-599
if nameInt < 100 || nameInt >= 600 {
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name)
continue
@ -108,6 +122,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
return fmt.Errorf("failed to read html file '%s': %w", name, err)
}
// create a callback function to write the page
m[nameInt] = func(rw http.ResponseWriter) {
rw.Header().Set("Content-Type", "text/html; encoding=utf-8")
rw.WriteHeader(nameInt)

View File

@ -18,6 +18,15 @@ import (
var ErrFaviconNotFound = errors.New("favicon not found")
// Favicons is a dynamic favicon generator which supports overwriting favicons
type Favicons struct {
db *sql.DB
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
}
// New creates a new dynamic favicon generator
func New(db *sql.DB, inkscapeCmd string) *Favicons {
f := &Favicons{
db: db,
@ -29,7 +38,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
// init favicons table
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
if err != nil {
log.Printf("[WARN] Failed to generate 'domains' table\n")
log.Printf("[WARN] Failed to generate 'favicons' table\n")
return nil
}
@ -38,22 +47,22 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
return f
}
type Favicons struct {
db *sql.DB
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
}
// Compile downloads the list of favicon mappings from the database and loads
// them and the target favicons into memory for faster lookups
func (f *Favicons) Compile() {
// async compile magic
go func() {
// new map
favicons := make(map[string]*FaviconList)
// compile map and check errors
err := f.internalCompile(favicons)
if err != nil {
// log compile errors
log.Printf("[Favicons] Compile failed: %s\n", err)
return
}
// lock while replacing the map
f.cLock.Lock()
f.faviconMap = favicons
@ -61,22 +70,27 @@ func (f *Favicons) Compile() {
}()
}
func (f *Favicons) GetIcons(host string) (*FaviconList, bool) {
// GetIcons returns the favicon list for the provided host or nil if no
// icon is found or generated
func (f *Favicons) GetIcons(host string) *FaviconList {
// read lock for safety
f.cLock.RLock()
defer f.cLock.RUnlock()
if a, ok := f.faviconMap[host]; ok {
return a, true
}
return nil, false
// return value from map
return f.faviconMap[host]
}
// internalCompile is a hidden internal method for loading and generating all
// favicons.
func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
// query all rows in database
query, err := f.db.Query(`select * from favicons`)
query, err := f.db.Query(`select host, svg, png, ico from favicons`)
if err != nil {
return fmt.Errorf("failed to prepare query: %w", err)
}
// loop over rows and scan in data using error group to catch errors
var g errgroup.Group
for query.Next() {
var host, rawSvg, rawPng, rawIco string
@ -85,12 +99,17 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
return fmt.Errorf("failed to scan row: %w", err)
}
// create favicon list for this row
l := &FaviconList{
Ico: CreateFaviconImage(rawIco),
Png: CreateFaviconImage(rawPng),
Svg: CreateFaviconImage(rawSvg),
}
// save the favicon list to the map
faviconMap[host] = l
// run the pre-process in a separate goroutine
g.Go(func() error {
return l.PreProcess(f.convertSvgToPng)
})
@ -98,16 +117,19 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
return g.Wait()
}
// convertSvgToPng calls svg2png which runs inkscape in a subprocess
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
return svg2png(f.cmd, in)
}
// FaviconList contains the ico, png and svg icons for separate favicons
type FaviconList struct {
Ico *FaviconImage // can be generated from png with wrapper
Png *FaviconImage // can be generated from svg with inkscape
Svg *FaviconImage
}
// ProduceIco outputs the bytes of the ico icon or an error
func (l *FaviconList) ProduceIco() ([]byte, error) {
if l.Ico == nil {
return nil, ErrFaviconNotFound
@ -115,6 +137,7 @@ func (l *FaviconList) ProduceIco() ([]byte, error) {
return l.Ico.Raw, nil
}
// ProducePng outputs the bytes of the png icon or an error
func (l *FaviconList) ProducePng() ([]byte, error) {
if l.Png == nil {
return nil, ErrFaviconNotFound
@ -122,6 +145,7 @@ func (l *FaviconList) ProducePng() ([]byte, error) {
return l.Png.Raw, nil
}
// ProduceSvg outputs the bytes of the svg icon or an error
func (l *FaviconList) ProduceSvg() ([]byte, error) {
if l.Svg == nil {
return nil, ErrFaviconNotFound
@ -129,6 +153,8 @@ func (l *FaviconList) ProduceSvg() ([]byte, error) {
return l.Svg.Raw, nil
}
// PreProcess takes an input of the svg2png conversion function and outputs
// an error if the SVG, PNG or ICO fails to download or generate
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
var err error
@ -178,10 +204,13 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
}
}
// generate sha256 hashes for svg, png and ico
l.genSha256()
return nil
}
// genSha256 generates sha256 hashes
func (l *FaviconList) genSha256() {
if l.Svg != nil {
l.Svg.Hash = genSha256(l.Svg.Raw)
@ -194,6 +223,8 @@ func (l *FaviconList) genSha256() {
}
}
// getFaviconViaRequest uses the standard http request library to download
// icons, outputs the raw bytes from the download or an error.
func getFaviconViaRequest(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@ -210,21 +241,27 @@ func getFaviconViaRequest(url string) ([]byte, error) {
return rawBody, nil
}
// genSha256 generates a sha256 hash as a hex encoded string
func genSha256(in []byte) string {
// create sha256 generator and write to it
h := sha256.New()
_, err := h.Write(in)
if err != nil {
return ""
}
// encode as hex
return hex.EncodeToString(h.Sum(nil))
}
// FaviconImage stores the url, hash and raw bytes of an image
type FaviconImage struct {
Url string
Hash string
Raw []byte
}
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
// the URL is an empty string.
func CreateFaviconImage(url string) *FaviconImage {
if url == "" {
return nil

View File

@ -6,24 +6,31 @@ import (
"os/exec"
)
// svg2png takes an input inkscape binary path and svg image bytes and outputs
// the png image bytes or an error.
func svg2png(inkscapeCmd string, in []byte) (out []byte, err error) {
// create stdout and stderr buffers
var stdout, stderr bytes.Buffer
// prepare inkscape command and attach buffers
cmd := exec.Command(inkscapeCmd, "--export-type", "png", "--export-filename", "-", "--export-background-opacity", "0", "--pipe")
cmd.Stdin = bytes.NewBuffer(in)
cmd.Stdout = &stdout
cmd.Stderr = &stderr
// run the command and return errors
if e := cmd.Run(); e != nil {
err = fmt.Errorf("%s\nSTDERR:\n%s", e.Error(), stderr.String())
return
}
// error if there is no output
if stdout.Len() == 0 {
err = fmt.Errorf("got no data from inkscape")
return
}
// return the raw bytes
out = stdout.Bytes()
return
}

View File

@ -46,7 +46,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
// Create and run http server
s := &http.Server{
Addr: listen,
Addr: conf.ApiListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,

View File

@ -5,6 +5,8 @@ import (
"strings"
)
// SplitDomainPort takes an input host and default port then outputs the domain,
// port and true or empty values and false if the split failed
func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) {
a := strings.SplitN(host, ":", 2)
switch len(a) {
@ -21,42 +23,63 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok
return
}
// GetDomainWithoutPort takes an input domain + port and outputs the domain
// without the port.
//
// example.com:443 => example.com
func GetDomainWithoutPort(domain string) (string, bool) {
a := strings.SplitN(domain, ":", 2)
if len(a) == 2 {
return a[0], true
}
if len(a) == 0 {
// if a valid index isn't found then return false
n := strings.LastIndexByte(domain, ':')
if n == -1 {
return "", false
}
return a[0], true
return domain[:n], true
}
// ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced
// with a wildcard '*' character.
//
// www.example.com => *.example.com
func ReplaceSubdomainWithWildcard(domain string) (string, bool) {
a, b := GetBaseDomain(domain)
return "*." + a, b
}
func GetBaseDomain(domain string) (string, bool) {
a := strings.SplitN(domain, ".", 2)
l := len(a)
if l == 2 {
return a[1], true
}
if l == 1 {
return a[0], true
}
// if a valid index isn't found then return false
n := strings.IndexByte(domain, '.')
if n == -1 {
return "", false
}
return "*" + domain[n:], true
}
// GetParentDomain returns the parent domain stripping off the subdomain.
//
// www.example.com => example.com
func GetParentDomain(domain string) (string, bool) {
// if a valid index isn't found then return false
n := strings.IndexByte(domain, '.')
if n == -1 {
return "", false
}
return domain[n+1:], true
}
// GetTopFqdn returns the top domain stripping off multiple layers of subdomains.
//
// hello.world.example.com => example.com
func GetTopFqdn(domain string) (string, bool) {
a := strings.Split(domain, ".")
l := len(a)
if l >= 2 {
return strings.Join(a[l-2:], "."), true
var countDot int
n := strings.LastIndexFunc(domain, func(r rune) bool {
// return true if this is the second '.'
// otherwise counts one and continues
if r == '.' {
if countDot == 1 {
return true
}
if l == 1 {
return domain, true
countDot++
}
return false
})
// if a valid index isn't found then return false
if n == -1 {
return "", false
}
return domain[n+1:], true
}

View File

@ -38,11 +38,11 @@ func TestReplaceSubdomainWithWildcard(t *testing.T) {
}
func TestGetBaseDomain(t *testing.T) {
domain, ok := GetBaseDomain("www.example.com")
domain, ok := GetParentDomain("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetBaseDomain("www.example.com:5612")
domain, ok = GetParentDomain("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com:5612", domain)
}

View File

@ -4,12 +4,8 @@ import (
"net/http"
)
var (
a1 = []byte("<a href=\"")
a2 = []byte("\">")
a3 = []byte("</a>.\n")
)
// FastRedirect adds a location header, status code and if the method is get,
// outputs the status text.
func FastRedirect(rw http.ResponseWriter, req *http.Request, url string, code int) {
rw.Header().Add("Location", url)
rw.WriteHeader(code)

View File

@ -1,11 +1,15 @@
package utils
// Compilable is an interface for structs with an asynchronous compile method.
type Compilable interface {
Compile()
}
// MultiCompilable is a slice of multiple Compilable interfaces.
type MultiCompilable []Compilable
// Compile loops over the slice of Compilable interfaces and calls Compile on
// each one.
func (m MultiCompilable) Compile() {
for _, i := range m {
i.Compile()

View File

@ -24,6 +24,8 @@ func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
}
// GetBearer returns the bearer from the Authorization header or an empty string
// if the authorization is empty or doesn't start with Bearer.
func GetBearer(req *http.Request) string {
a := req.Header.Get("Authorization")
if t, ok := strings.CutPrefix(a, "Bearer "); ok {