mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 10:51:40 +00:00
Comments!
This commit is contained in:
parent
71f8aaaf16
commit
0551e15979
@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Certs is the certificate loader and management system.
|
||||
type Certs struct {
|
||||
cDir fs.FS
|
||||
kDir fs.FS
|
||||
@ -18,6 +19,7 @@ type Certs struct {
|
||||
m map[string]*tls.Certificate
|
||||
}
|
||||
|
||||
// New creates a new cert list
|
||||
func New(certDir fs.FS, keyDir fs.FS) *Certs {
|
||||
a := &Certs{
|
||||
cDir: certDir,
|
||||
@ -25,6 +27,8 @@ func New(certDir fs.FS, keyDir fs.FS) *Certs {
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
// run compile to get the initial data
|
||||
a.Compile()
|
||||
return a
|
||||
}
|
||||
@ -62,6 +66,7 @@ func (c *Certs) Compile() {
|
||||
log.Printf("[Certs] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// lock while replacing the map
|
||||
c.s.Lock()
|
||||
c.m = certMap
|
||||
@ -69,6 +74,8 @@ func (c *Certs) Compile() {
|
||||
}()
|
||||
}
|
||||
|
||||
// internalCompile is a hidden internal method for loading the certificate and
|
||||
// key files
|
||||
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
|
||||
// try to read dir
|
||||
files, err := fs.ReadDir(c.cDir, "")
|
||||
|
@ -16,8 +16,9 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// flags - each one has a usage field lol
|
||||
var (
|
||||
databasePath = flag.String("db", "", "/path/to/database.sqlite")
|
||||
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
|
||||
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
|
||||
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
|
||||
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
|
||||
@ -30,11 +31,12 @@ var (
|
||||
func main() {
|
||||
log.Println("[Violet] Starting...")
|
||||
|
||||
// create paths
|
||||
// create path to cert dir
|
||||
err := os.MkdirAll(*certPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
|
||||
}
|
||||
// create path to key dir
|
||||
err = os.MkdirAll(*keyPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
|
||||
@ -52,6 +54,7 @@ func main() {
|
||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
||||
|
||||
// struct containing config for the http servers
|
||||
srvConf := &servers.Conf{
|
||||
ApiListen: *apiListen,
|
||||
HttpListen: *httpListen,
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Domains is the domain list and management system.
|
||||
type Domains struct {
|
||||
db *sql.DB
|
||||
s *sync.RWMutex
|
||||
@ -64,12 +65,16 @@ func (d *Domains) IsValid(host string) bool {
|
||||
func (d *Domains) Compile() {
|
||||
// async compile magic
|
||||
go func() {
|
||||
// new map
|
||||
domainMap := make(map[string]struct{})
|
||||
|
||||
// compile map and check errors
|
||||
err := d.internalCompile(domainMap)
|
||||
if err != nil {
|
||||
log.Printf("[Domains] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// lock while replacing the map
|
||||
d.s.Lock()
|
||||
d.m = domainMap
|
||||
|
@ -20,29 +20,40 @@ type ErrorPages struct {
|
||||
dir fs.FS
|
||||
}
|
||||
|
||||
// New creates a new error pages generator
|
||||
func New(dir fs.FS) *ErrorPages {
|
||||
return &ErrorPages{
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[int]func(rw http.ResponseWriter)),
|
||||
// generic error page writer
|
||||
generic: func(rw http.ResponseWriter, code int) {
|
||||
// if status text is empty then the code is unknown
|
||||
a := http.StatusText(code)
|
||||
if a != "" {
|
||||
// output in "xxx Error Text" format
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
|
||||
return
|
||||
}
|
||||
// output the code and generic unknown message
|
||||
http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code)
|
||||
},
|
||||
dir: dir,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeError writes the error page for the given code to the response writer
|
||||
func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
|
||||
// read lock for safety
|
||||
e.s.RLock()
|
||||
defer e.s.RUnlock()
|
||||
|
||||
// use the custom error page if it exists
|
||||
if p, ok := e.m[code]; ok {
|
||||
p(rw)
|
||||
return
|
||||
}
|
||||
|
||||
// otherwise use the generic error page
|
||||
e.generic(rw, code)
|
||||
}
|
||||
|
||||
@ -58,6 +69,7 @@ func (e *ErrorPages) Compile() {
|
||||
log.Printf("[Certs] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// lock while replacing the map
|
||||
e.s.Lock()
|
||||
e.m = errorPageMap
|
||||
@ -97,6 +109,8 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
|
||||
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
// check if code is in range 100-599
|
||||
if nameInt < 100 || nameInt >= 600 {
|
||||
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name)
|
||||
continue
|
||||
@ -108,6 +122,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
|
||||
return fmt.Errorf("failed to read html file '%s': %w", name, err)
|
||||
}
|
||||
|
||||
// create a callback function to write the page
|
||||
m[nameInt] = func(rw http.ResponseWriter) {
|
||||
rw.Header().Set("Content-Type", "text/html; encoding=utf-8")
|
||||
rw.WriteHeader(nameInt)
|
||||
|
@ -18,6 +18,15 @@ import (
|
||||
|
||||
var ErrFaviconNotFound = errors.New("favicon not found")
|
||||
|
||||
// Favicons is a dynamic favicon generator which supports overwriting favicons
|
||||
type Favicons struct {
|
||||
db *sql.DB
|
||||
cmd string
|
||||
cLock *sync.RWMutex
|
||||
faviconMap map[string]*FaviconList
|
||||
}
|
||||
|
||||
// New creates a new dynamic favicon generator
|
||||
func New(db *sql.DB, inkscapeCmd string) *Favicons {
|
||||
f := &Favicons{
|
||||
db: db,
|
||||
@ -29,7 +38,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
|
||||
// init favicons table
|
||||
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'domains' table\n")
|
||||
log.Printf("[WARN] Failed to generate 'favicons' table\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -38,22 +47,22 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
|
||||
return f
|
||||
}
|
||||
|
||||
type Favicons struct {
|
||||
db *sql.DB
|
||||
cmd string
|
||||
cLock *sync.RWMutex
|
||||
faviconMap map[string]*FaviconList
|
||||
}
|
||||
|
||||
// Compile downloads the list of favicon mappings from the database and loads
|
||||
// them and the target favicons into memory for faster lookups
|
||||
func (f *Favicons) Compile() {
|
||||
// async compile magic
|
||||
go func() {
|
||||
// new map
|
||||
favicons := make(map[string]*FaviconList)
|
||||
|
||||
// compile map and check errors
|
||||
err := f.internalCompile(favicons)
|
||||
if err != nil {
|
||||
// log compile errors
|
||||
log.Printf("[Favicons] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// lock while replacing the map
|
||||
f.cLock.Lock()
|
||||
f.faviconMap = favicons
|
||||
@ -61,22 +70,27 @@ func (f *Favicons) Compile() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (f *Favicons) GetIcons(host string) (*FaviconList, bool) {
|
||||
// GetIcons returns the favicon list for the provided host or nil if no
|
||||
// icon is found or generated
|
||||
func (f *Favicons) GetIcons(host string) *FaviconList {
|
||||
// read lock for safety
|
||||
f.cLock.RLock()
|
||||
defer f.cLock.RUnlock()
|
||||
if a, ok := f.faviconMap[host]; ok {
|
||||
return a, true
|
||||
}
|
||||
return nil, false
|
||||
|
||||
// return value from map
|
||||
return f.faviconMap[host]
|
||||
}
|
||||
|
||||
// internalCompile is a hidden internal method for loading and generating all
|
||||
// favicons.
|
||||
func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
|
||||
// query all rows in database
|
||||
query, err := f.db.Query(`select * from favicons`)
|
||||
query, err := f.db.Query(`select host, svg, png, ico from favicons`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare query: %w", err)
|
||||
}
|
||||
|
||||
// loop over rows and scan in data using error group to catch errors
|
||||
var g errgroup.Group
|
||||
for query.Next() {
|
||||
var host, rawSvg, rawPng, rawIco string
|
||||
@ -85,12 +99,17 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
|
||||
return fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
// create favicon list for this row
|
||||
l := &FaviconList{
|
||||
Ico: CreateFaviconImage(rawIco),
|
||||
Png: CreateFaviconImage(rawPng),
|
||||
Svg: CreateFaviconImage(rawSvg),
|
||||
}
|
||||
|
||||
// save the favicon list to the map
|
||||
faviconMap[host] = l
|
||||
|
||||
// run the pre-process in a separate goroutine
|
||||
g.Go(func() error {
|
||||
return l.PreProcess(f.convertSvgToPng)
|
||||
})
|
||||
@ -98,16 +117,19 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// convertSvgToPng calls svg2png which runs inkscape in a subprocess
|
||||
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
|
||||
return svg2png(f.cmd, in)
|
||||
}
|
||||
|
||||
// FaviconList contains the ico, png and svg icons for separate favicons
|
||||
type FaviconList struct {
|
||||
Ico *FaviconImage // can be generated from png with wrapper
|
||||
Png *FaviconImage // can be generated from svg with inkscape
|
||||
Svg *FaviconImage
|
||||
}
|
||||
|
||||
// ProduceIco outputs the bytes of the ico icon or an error
|
||||
func (l *FaviconList) ProduceIco() ([]byte, error) {
|
||||
if l.Ico == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
@ -115,6 +137,7 @@ func (l *FaviconList) ProduceIco() ([]byte, error) {
|
||||
return l.Ico.Raw, nil
|
||||
}
|
||||
|
||||
// ProducePng outputs the bytes of the png icon or an error
|
||||
func (l *FaviconList) ProducePng() ([]byte, error) {
|
||||
if l.Png == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
@ -122,6 +145,7 @@ func (l *FaviconList) ProducePng() ([]byte, error) {
|
||||
return l.Png.Raw, nil
|
||||
}
|
||||
|
||||
// ProduceSvg outputs the bytes of the svg icon or an error
|
||||
func (l *FaviconList) ProduceSvg() ([]byte, error) {
|
||||
if l.Svg == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
@ -129,6 +153,8 @@ func (l *FaviconList) ProduceSvg() ([]byte, error) {
|
||||
return l.Svg.Raw, nil
|
||||
}
|
||||
|
||||
// PreProcess takes an input of the svg2png conversion function and outputs
|
||||
// an error if the SVG, PNG or ICO fails to download or generate
|
||||
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
|
||||
var err error
|
||||
|
||||
@ -178,10 +204,13 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
|
||||
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// generate sha256 hashes for svg, png and ico
|
||||
l.genSha256()
|
||||
return nil
|
||||
}
|
||||
|
||||
// genSha256 generates sha256 hashes
|
||||
func (l *FaviconList) genSha256() {
|
||||
if l.Svg != nil {
|
||||
l.Svg.Hash = genSha256(l.Svg.Raw)
|
||||
@ -194,6 +223,8 @@ func (l *FaviconList) genSha256() {
|
||||
}
|
||||
}
|
||||
|
||||
// getFaviconViaRequest uses the standard http request library to download
|
||||
// icons, outputs the raw bytes from the download or an error.
|
||||
func getFaviconViaRequest(url string) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@ -210,21 +241,27 @@ func getFaviconViaRequest(url string) ([]byte, error) {
|
||||
return rawBody, nil
|
||||
}
|
||||
|
||||
// genSha256 generates a sha256 hash as a hex encoded string
|
||||
func genSha256(in []byte) string {
|
||||
// create sha256 generator and write to it
|
||||
h := sha256.New()
|
||||
_, err := h.Write(in)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// encode as hex
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// FaviconImage stores the url, hash and raw bytes of an image
|
||||
type FaviconImage struct {
|
||||
Url string
|
||||
Hash string
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
|
||||
// the URL is an empty string.
|
||||
func CreateFaviconImage(url string) *FaviconImage {
|
||||
if url == "" {
|
||||
return nil
|
||||
|
@ -6,24 +6,31 @@ import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// svg2png takes an input inkscape binary path and svg image bytes and outputs
|
||||
// the png image bytes or an error.
|
||||
func svg2png(inkscapeCmd string, in []byte) (out []byte, err error) {
|
||||
// create stdout and stderr buffers
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
// prepare inkscape command and attach buffers
|
||||
cmd := exec.Command(inkscapeCmd, "--export-type", "png", "--export-filename", "-", "--export-background-opacity", "0", "--pipe")
|
||||
cmd.Stdin = bytes.NewBuffer(in)
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// run the command and return errors
|
||||
if e := cmd.Run(); e != nil {
|
||||
err = fmt.Errorf("%s\nSTDERR:\n%s", e.Error(), stderr.String())
|
||||
return
|
||||
}
|
||||
|
||||
// error if there is no output
|
||||
if stdout.Len() == 0 {
|
||||
err = fmt.Errorf("got no data from inkscape")
|
||||
return
|
||||
}
|
||||
|
||||
// return the raw bytes
|
||||
out = stdout.Bytes()
|
||||
return
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
|
||||
|
||||
// Create and run http server
|
||||
s := &http.Server{
|
||||
Addr: listen,
|
||||
Addr: conf.ApiListen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SplitDomainPort takes an input host and default port then outputs the domain,
|
||||
// port and true or empty values and false if the split failed
|
||||
func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) {
|
||||
a := strings.SplitN(host, ":", 2)
|
||||
switch len(a) {
|
||||
@ -21,42 +23,63 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok
|
||||
return
|
||||
}
|
||||
|
||||
// GetDomainWithoutPort takes an input domain + port and outputs the domain
|
||||
// without the port.
|
||||
//
|
||||
// example.com:443 => example.com
|
||||
func GetDomainWithoutPort(domain string) (string, bool) {
|
||||
a := strings.SplitN(domain, ":", 2)
|
||||
if len(a) == 2 {
|
||||
return a[0], true
|
||||
}
|
||||
if len(a) == 0 {
|
||||
// if a valid index isn't found then return false
|
||||
n := strings.LastIndexByte(domain, ':')
|
||||
if n == -1 {
|
||||
return "", false
|
||||
}
|
||||
return a[0], true
|
||||
return domain[:n], true
|
||||
}
|
||||
|
||||
// ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced
|
||||
// with a wildcard '*' character.
|
||||
//
|
||||
// www.example.com => *.example.com
|
||||
func ReplaceSubdomainWithWildcard(domain string) (string, bool) {
|
||||
a, b := GetBaseDomain(domain)
|
||||
return "*." + a, b
|
||||
}
|
||||
|
||||
func GetBaseDomain(domain string) (string, bool) {
|
||||
a := strings.SplitN(domain, ".", 2)
|
||||
l := len(a)
|
||||
if l == 2 {
|
||||
return a[1], true
|
||||
}
|
||||
if l == 1 {
|
||||
return a[0], true
|
||||
}
|
||||
// if a valid index isn't found then return false
|
||||
n := strings.IndexByte(domain, '.')
|
||||
if n == -1 {
|
||||
return "", false
|
||||
}
|
||||
return "*" + domain[n:], true
|
||||
}
|
||||
|
||||
// GetParentDomain returns the parent domain stripping off the subdomain.
|
||||
//
|
||||
// www.example.com => example.com
|
||||
func GetParentDomain(domain string) (string, bool) {
|
||||
// if a valid index isn't found then return false
|
||||
n := strings.IndexByte(domain, '.')
|
||||
if n == -1 {
|
||||
return "", false
|
||||
}
|
||||
return domain[n+1:], true
|
||||
}
|
||||
|
||||
// GetTopFqdn returns the top domain stripping off multiple layers of subdomains.
|
||||
//
|
||||
// hello.world.example.com => example.com
|
||||
func GetTopFqdn(domain string) (string, bool) {
|
||||
a := strings.Split(domain, ".")
|
||||
l := len(a)
|
||||
if l >= 2 {
|
||||
return strings.Join(a[l-2:], "."), true
|
||||
var countDot int
|
||||
n := strings.LastIndexFunc(domain, func(r rune) bool {
|
||||
// return true if this is the second '.'
|
||||
// otherwise counts one and continues
|
||||
if r == '.' {
|
||||
if countDot == 1 {
|
||||
return true
|
||||
}
|
||||
if l == 1 {
|
||||
return domain, true
|
||||
countDot++
|
||||
}
|
||||
return false
|
||||
})
|
||||
// if a valid index isn't found then return false
|
||||
if n == -1 {
|
||||
return "", false
|
||||
}
|
||||
return domain[n+1:], true
|
||||
}
|
||||
|
@ -38,11 +38,11 @@ func TestReplaceSubdomainWithWildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetBaseDomain(t *testing.T) {
|
||||
domain, ok := GetBaseDomain("www.example.com")
|
||||
domain, ok := GetParentDomain("www.example.com")
|
||||
assert.True(t, ok, "Output should be true")
|
||||
assert.Equal(t, "example.com", domain)
|
||||
|
||||
domain, ok = GetBaseDomain("www.example.com:5612")
|
||||
domain, ok = GetParentDomain("www.example.com:5612")
|
||||
assert.True(t, ok, "Output should be true")
|
||||
assert.Equal(t, "example.com:5612", domain)
|
||||
}
|
||||
|
@ -4,12 +4,8 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
a1 = []byte("<a href=\"")
|
||||
a2 = []byte("\">")
|
||||
a3 = []byte("</a>.\n")
|
||||
)
|
||||
|
||||
// FastRedirect adds a location header, status code and if the method is get,
|
||||
// outputs the status text.
|
||||
func FastRedirect(rw http.ResponseWriter, req *http.Request, url string, code int) {
|
||||
rw.Header().Add("Location", url)
|
||||
rw.WriteHeader(code)
|
||||
|
@ -1,11 +1,15 @@
|
||||
package utils
|
||||
|
||||
// Compilable is an interface for structs with an asynchronous compile method.
|
||||
type Compilable interface {
|
||||
Compile()
|
||||
}
|
||||
|
||||
// MultiCompilable is a slice of multiple Compilable interfaces.
|
||||
type MultiCompilable []Compilable
|
||||
|
||||
// Compile loops over the slice of Compilable interfaces and calls Compile on
|
||||
// each one.
|
||||
func (m MultiCompilable) Compile() {
|
||||
for _, i := range m {
|
||||
i.Compile()
|
||||
|
@ -24,6 +24,8 @@ func RunBackgroundHttps(prefix string, s *http.Server) {
|
||||
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
|
||||
}
|
||||
|
||||
// GetBearer returns the bearer from the Authorization header or an empty string
|
||||
// if the authorization is empty or doesn't start with Bearer.
|
||||
func GetBearer(req *http.Request) string {
|
||||
a := req.Header.Get("Authorization")
|
||||
if t, ok := strings.CutPrefix(a, "Bearer "); ok {
|
||||
|
Loading…
Reference in New Issue
Block a user