Comments!

This commit is contained in:
Melon 2023-04-24 01:35:23 +01:00
parent 71f8aaaf16
commit 0551e15979
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
12 changed files with 149 additions and 50 deletions

View File

@ -11,6 +11,7 @@ import (
"sync" "sync"
) )
// Certs is the certificate loader and management system.
type Certs struct { type Certs struct {
cDir fs.FS cDir fs.FS
kDir fs.FS kDir fs.FS
@ -18,6 +19,7 @@ type Certs struct {
m map[string]*tls.Certificate m map[string]*tls.Certificate
} }
// New creates a new cert list
func New(certDir fs.FS, keyDir fs.FS) *Certs { func New(certDir fs.FS, keyDir fs.FS) *Certs {
a := &Certs{ a := &Certs{
cDir: certDir, cDir: certDir,
@ -25,6 +27,8 @@ func New(certDir fs.FS, keyDir fs.FS) *Certs {
s: &sync.RWMutex{}, s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate), m: make(map[string]*tls.Certificate),
} }
// run compile to get the initial data
a.Compile() a.Compile()
return a return a
} }
@ -62,6 +66,7 @@ func (c *Certs) Compile() {
log.Printf("[Certs] Compile failed: %s\n", err) log.Printf("[Certs] Compile failed: %s\n", err)
return return
} }
// lock while replacing the map // lock while replacing the map
c.s.Lock() c.s.Lock()
c.m = certMap c.m = certMap
@ -69,6 +74,8 @@ func (c *Certs) Compile() {
}() }()
} }
// internalCompile is a hidden internal method for loading the certificate and
// key files
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error { func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
// try to read dir // try to read dir
files, err := fs.ReadDir(c.cDir, "") files, err := fs.ReadDir(c.cDir, "")

View File

@ -16,8 +16,9 @@ import (
"os" "os"
) )
// flags - each one has a usage field lol
var ( var (
databasePath = flag.String("db", "", "/path/to/database.sqlite") databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages") errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
@ -30,11 +31,12 @@ var (
func main() { func main() {
log.Println("[Violet] Starting...") log.Println("[Violet] Starting...")
// create paths // create path to cert dir
err := os.MkdirAll(*certPath, os.ModePerm) err := os.MkdirAll(*certPath, os.ModePerm)
if err != nil { if err != nil {
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
} }
// create path to key dir
err = os.MkdirAll(*keyPath, os.ModePerm) err = os.MkdirAll(*keyPath, os.ModePerm)
if err != nil { if err != nil {
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
@ -52,6 +54,7 @@ func main() {
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
// struct containing config for the http servers
srvConf := &servers.Conf{ srvConf := &servers.Conf{
ApiListen: *apiListen, ApiListen: *apiListen,
HttpListen: *httpListen, HttpListen: *httpListen,

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
) )
// Domains is the domain list and management system.
type Domains struct { type Domains struct {
db *sql.DB db *sql.DB
s *sync.RWMutex s *sync.RWMutex
@ -64,12 +65,16 @@ func (d *Domains) IsValid(host string) bool {
func (d *Domains) Compile() { func (d *Domains) Compile() {
// async compile magic // async compile magic
go func() { go func() {
// new map
domainMap := make(map[string]struct{}) domainMap := make(map[string]struct{})
// compile map and check errors
err := d.internalCompile(domainMap) err := d.internalCompile(domainMap)
if err != nil { if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err) log.Printf("[Domains] Compile failed: %s\n", err)
return return
} }
// lock while replacing the map // lock while replacing the map
d.s.Lock() d.s.Lock()
d.m = domainMap d.m = domainMap

View File

@ -20,29 +20,40 @@ type ErrorPages struct {
dir fs.FS dir fs.FS
} }
// New creates a new error pages generator
func New(dir fs.FS) *ErrorPages { func New(dir fs.FS) *ErrorPages {
return &ErrorPages{ return &ErrorPages{
s: &sync.RWMutex{}, s: &sync.RWMutex{},
m: make(map[int]func(rw http.ResponseWriter)), m: make(map[int]func(rw http.ResponseWriter)),
// generic error page writer
generic: func(rw http.ResponseWriter, code int) { generic: func(rw http.ResponseWriter, code int) {
// if status text is empty then the code is unknown
a := http.StatusText(code) a := http.StatusText(code)
if a != "" { if a != "" {
// output in "xxx Error Text" format
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code) http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
return return
} }
// output the code and generic unknown message
http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code) http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code)
}, },
dir: dir, dir: dir,
} }
} }
// ServeError writes the error page for the given code to the response writer
func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) { func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
// read lock for safety
e.s.RLock() e.s.RLock()
defer e.s.RUnlock() defer e.s.RUnlock()
// use the custom error page if it exists
if p, ok := e.m[code]; ok { if p, ok := e.m[code]; ok {
p(rw) p(rw)
return return
} }
// otherwise use the generic error page
e.generic(rw, code) e.generic(rw, code)
} }
@ -58,6 +69,7 @@ func (e *ErrorPages) Compile() {
log.Printf("[Certs] Compile failed: %s\n", err) log.Printf("[Certs] Compile failed: %s\n", err)
return return
} }
// lock while replacing the map // lock while replacing the map
e.s.Lock() e.s.Lock()
e.m = errorPageMap e.m = errorPageMap
@ -97,6 +109,8 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name) log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name)
continue continue
} }
// check if code is in range 100-599
if nameInt < 100 || nameInt >= 600 { if nameInt < 100 || nameInt >= 600 {
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name) log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name)
continue continue
@ -108,6 +122,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
return fmt.Errorf("failed to read html file '%s': %w", name, err) return fmt.Errorf("failed to read html file '%s': %w", name, err)
} }
// create a callback function to write the page
m[nameInt] = func(rw http.ResponseWriter) { m[nameInt] = func(rw http.ResponseWriter) {
rw.Header().Set("Content-Type", "text/html; encoding=utf-8") rw.Header().Set("Content-Type", "text/html; encoding=utf-8")
rw.WriteHeader(nameInt) rw.WriteHeader(nameInt)

View File

@ -18,6 +18,15 @@ import (
var ErrFaviconNotFound = errors.New("favicon not found") var ErrFaviconNotFound = errors.New("favicon not found")
// Favicons is a dynamic favicon generator which supports overwriting favicons
type Favicons struct {
db *sql.DB
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
}
// New creates a new dynamic favicon generator
func New(db *sql.DB, inkscapeCmd string) *Favicons { func New(db *sql.DB, inkscapeCmd string) *Favicons {
f := &Favicons{ f := &Favicons{
db: db, db: db,
@ -29,7 +38,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
// init favicons table // init favicons table
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`) _, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
if err != nil { if err != nil {
log.Printf("[WARN] Failed to generate 'domains' table\n") log.Printf("[WARN] Failed to generate 'favicons' table\n")
return nil return nil
} }
@ -38,22 +47,22 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
return f return f
} }
type Favicons struct { // Compile downloads the list of favicon mappings from the database and loads
db *sql.DB // them and the target favicons into memory for faster lookups
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
}
func (f *Favicons) Compile() { func (f *Favicons) Compile() {
// async compile magic
go func() { go func() {
// new map
favicons := make(map[string]*FaviconList) favicons := make(map[string]*FaviconList)
// compile map and check errors
err := f.internalCompile(favicons) err := f.internalCompile(favicons)
if err != nil { if err != nil {
// log compile errors // log compile errors
log.Printf("[Favicons] Compile failed: %s\n", err) log.Printf("[Favicons] Compile failed: %s\n", err)
return return
} }
// lock while replacing the map // lock while replacing the map
f.cLock.Lock() f.cLock.Lock()
f.faviconMap = favicons f.faviconMap = favicons
@ -61,22 +70,27 @@ func (f *Favicons) Compile() {
}() }()
} }
func (f *Favicons) GetIcons(host string) (*FaviconList, bool) { // GetIcons returns the favicon list for the provided host or nil if no
// icon is found or generated
func (f *Favicons) GetIcons(host string) *FaviconList {
// read lock for safety
f.cLock.RLock() f.cLock.RLock()
defer f.cLock.RUnlock() defer f.cLock.RUnlock()
if a, ok := f.faviconMap[host]; ok {
return a, true // return value from map
} return f.faviconMap[host]
return nil, false
} }
// internalCompile is a hidden internal method for loading and generating all
// favicons.
func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
// query all rows in database // query all rows in database
query, err := f.db.Query(`select * from favicons`) query, err := f.db.Query(`select host, svg, png, ico from favicons`)
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare query: %w", err) return fmt.Errorf("failed to prepare query: %w", err)
} }
// loop over rows and scan in data using error group to catch errors
var g errgroup.Group var g errgroup.Group
for query.Next() { for query.Next() {
var host, rawSvg, rawPng, rawIco string var host, rawSvg, rawPng, rawIco string
@ -85,12 +99,17 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
return fmt.Errorf("failed to scan row: %w", err) return fmt.Errorf("failed to scan row: %w", err)
} }
// create favicon list for this row
l := &FaviconList{ l := &FaviconList{
Ico: CreateFaviconImage(rawIco), Ico: CreateFaviconImage(rawIco),
Png: CreateFaviconImage(rawPng), Png: CreateFaviconImage(rawPng),
Svg: CreateFaviconImage(rawSvg), Svg: CreateFaviconImage(rawSvg),
} }
// save the favicon list to the map
faviconMap[host] = l faviconMap[host] = l
// run the pre-process in a separate goroutine
g.Go(func() error { g.Go(func() error {
return l.PreProcess(f.convertSvgToPng) return l.PreProcess(f.convertSvgToPng)
}) })
@ -98,16 +117,19 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
return g.Wait() return g.Wait()
} }
// convertSvgToPng calls svg2png which runs inkscape in a subprocess
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) { func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
return svg2png(f.cmd, in) return svg2png(f.cmd, in)
} }
// FaviconList contains the ico, png and svg icons for separate favicons
type FaviconList struct { type FaviconList struct {
Ico *FaviconImage // can be generated from png with wrapper Ico *FaviconImage // can be generated from png with wrapper
Png *FaviconImage // can be generated from svg with inkscape Png *FaviconImage // can be generated from svg with inkscape
Svg *FaviconImage Svg *FaviconImage
} }
// ProduceIco outputs the bytes of the ico icon or an error
func (l *FaviconList) ProduceIco() ([]byte, error) { func (l *FaviconList) ProduceIco() ([]byte, error) {
if l.Ico == nil { if l.Ico == nil {
return nil, ErrFaviconNotFound return nil, ErrFaviconNotFound
@ -115,6 +137,7 @@ func (l *FaviconList) ProduceIco() ([]byte, error) {
return l.Ico.Raw, nil return l.Ico.Raw, nil
} }
// ProducePng outputs the bytes of the png icon or an error
func (l *FaviconList) ProducePng() ([]byte, error) { func (l *FaviconList) ProducePng() ([]byte, error) {
if l.Png == nil { if l.Png == nil {
return nil, ErrFaviconNotFound return nil, ErrFaviconNotFound
@ -122,6 +145,7 @@ func (l *FaviconList) ProducePng() ([]byte, error) {
return l.Png.Raw, nil return l.Png.Raw, nil
} }
// ProduceSvg outputs the bytes of the svg icon or an error
func (l *FaviconList) ProduceSvg() ([]byte, error) { func (l *FaviconList) ProduceSvg() ([]byte, error) {
if l.Svg == nil { if l.Svg == nil {
return nil, ErrFaviconNotFound return nil, ErrFaviconNotFound
@ -129,6 +153,8 @@ func (l *FaviconList) ProduceSvg() ([]byte, error) {
return l.Svg.Raw, nil return l.Svg.Raw, nil
} }
// PreProcess takes an input of the svg2png conversion function and outputs
// an error if the SVG, PNG or ICO fails to download or generate
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error { func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
var err error var err error
@ -178,10 +204,13 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err) return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
} }
} }
// generate sha256 hashes for svg, png and ico
l.genSha256() l.genSha256()
return nil return nil
} }
// genSha256 generates sha256 hashes
func (l *FaviconList) genSha256() { func (l *FaviconList) genSha256() {
if l.Svg != nil { if l.Svg != nil {
l.Svg.Hash = genSha256(l.Svg.Raw) l.Svg.Hash = genSha256(l.Svg.Raw)
@ -194,6 +223,8 @@ func (l *FaviconList) genSha256() {
} }
} }
// getFaviconViaRequest uses the standard http request library to download
// icons, outputs the raw bytes from the download or an error.
func getFaviconViaRequest(url string) ([]byte, error) { func getFaviconViaRequest(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {
@ -210,21 +241,27 @@ func getFaviconViaRequest(url string) ([]byte, error) {
return rawBody, nil return rawBody, nil
} }
// genSha256 generates a sha256 hash as a hex encoded string
func genSha256(in []byte) string { func genSha256(in []byte) string {
// create sha256 generator and write to it
h := sha256.New() h := sha256.New()
_, err := h.Write(in) _, err := h.Write(in)
if err != nil { if err != nil {
return "" return ""
} }
// encode as hex
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }
// FaviconImage stores the url, hash and raw bytes of an image
type FaviconImage struct { type FaviconImage struct {
Url string Url string
Hash string Hash string
Raw []byte Raw []byte
} }
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
// the URL is an empty string.
func CreateFaviconImage(url string) *FaviconImage { func CreateFaviconImage(url string) *FaviconImage {
if url == "" { if url == "" {
return nil return nil

View File

@ -6,24 +6,31 @@ import (
"os/exec" "os/exec"
) )
// svg2png takes an input inkscape binary path and svg image bytes and outputs
// the png image bytes or an error.
func svg2png(inkscapeCmd string, in []byte) (out []byte, err error) { func svg2png(inkscapeCmd string, in []byte) (out []byte, err error) {
// create stdout and stderr buffers
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
// prepare inkscape command and attach buffers
cmd := exec.Command(inkscapeCmd, "--export-type", "png", "--export-filename", "-", "--export-background-opacity", "0", "--pipe") cmd := exec.Command(inkscapeCmd, "--export-type", "png", "--export-filename", "-", "--export-background-opacity", "0", "--pipe")
cmd.Stdin = bytes.NewBuffer(in) cmd.Stdin = bytes.NewBuffer(in)
cmd.Stdout = &stdout cmd.Stdout = &stdout
cmd.Stderr = &stderr cmd.Stderr = &stderr
// run the command and return errors
if e := cmd.Run(); e != nil { if e := cmd.Run(); e != nil {
err = fmt.Errorf("%s\nSTDERR:\n%s", e.Error(), stderr.String()) err = fmt.Errorf("%s\nSTDERR:\n%s", e.Error(), stderr.String())
return return
} }
// error if there is no output
if stdout.Len() == 0 { if stdout.Len() == 0 {
err = fmt.Errorf("got no data from inkscape") err = fmt.Errorf("got no data from inkscape")
return return
} }
// return the raw bytes
out = stdout.Bytes() out = stdout.Bytes()
return return
} }

View File

@ -46,7 +46,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
// Create and run http server // Create and run http server
s := &http.Server{ s := &http.Server{
Addr: listen, Addr: conf.ApiListen,
Handler: r, Handler: r,
ReadTimeout: time.Minute, ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute, ReadHeaderTimeout: time.Minute,

View File

@ -5,6 +5,8 @@ import (
"strings" "strings"
) )
// SplitDomainPort takes an input host and default port then outputs the domain,
// port and true or empty values and false if the split failed
func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) { func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) {
a := strings.SplitN(host, ":", 2) a := strings.SplitN(host, ":", 2)
switch len(a) { switch len(a) {
@ -21,42 +23,63 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok
return return
} }
// GetDomainWithoutPort takes an input domain + port and outputs the domain
// without the port.
//
// example.com:443 => example.com
func GetDomainWithoutPort(domain string) (string, bool) { func GetDomainWithoutPort(domain string) (string, bool) {
a := strings.SplitN(domain, ":", 2) // if a valid index isn't found then return false
if len(a) == 2 { n := strings.LastIndexByte(domain, ':')
return a[0], true if n == -1 {
}
if len(a) == 0 {
return "", false return "", false
} }
return a[0], true return domain[:n], true
} }
// ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced
// with a wildcard '*' character.
//
// www.example.com => *.example.com
func ReplaceSubdomainWithWildcard(domain string) (string, bool) { func ReplaceSubdomainWithWildcard(domain string) (string, bool) {
a, b := GetBaseDomain(domain) // if a valid index isn't found then return false
return "*." + a, b n := strings.IndexByte(domain, '.')
if n == -1 {
return "", false
}
return "*" + domain[n:], true
} }
func GetBaseDomain(domain string) (string, bool) { // GetParentDomain returns the parent domain stripping off the subdomain.
a := strings.SplitN(domain, ".", 2) //
l := len(a) // www.example.com => example.com
if l == 2 { func GetParentDomain(domain string) (string, bool) {
return a[1], true // if a valid index isn't found then return false
n := strings.IndexByte(domain, '.')
if n == -1 {
return "", false
} }
if l == 1 { return domain[n+1:], true
return a[0], true
}
return "", false
} }
// GetTopFqdn returns the top domain stripping off multiple layers of subdomains.
//
// hello.world.example.com => example.com
func GetTopFqdn(domain string) (string, bool) { func GetTopFqdn(domain string) (string, bool) {
a := strings.Split(domain, ".") var countDot int
l := len(a) n := strings.LastIndexFunc(domain, func(r rune) bool {
if l >= 2 { // return true if this is the second '.'
return strings.Join(a[l-2:], "."), true // otherwise counts one and continues
if r == '.' {
if countDot == 1 {
return true
}
countDot++
}
return false
})
// if a valid index isn't found then return false
if n == -1 {
return "", false
} }
if l == 1 { return domain[n+1:], true
return domain, true
}
return "", false
} }

View File

@ -38,11 +38,11 @@ func TestReplaceSubdomainWithWildcard(t *testing.T) {
} }
func TestGetBaseDomain(t *testing.T) { func TestGetBaseDomain(t *testing.T) {
domain, ok := GetBaseDomain("www.example.com") domain, ok := GetParentDomain("www.example.com")
assert.True(t, ok, "Output should be true") assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain) assert.Equal(t, "example.com", domain)
domain, ok = GetBaseDomain("www.example.com:5612") domain, ok = GetParentDomain("www.example.com:5612")
assert.True(t, ok, "Output should be true") assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com:5612", domain) assert.Equal(t, "example.com:5612", domain)
} }

View File

@ -4,12 +4,8 @@ import (
"net/http" "net/http"
) )
var ( // FastRedirect adds a location header, status code and if the method is get,
a1 = []byte("<a href=\"") // outputs the status text.
a2 = []byte("\">")
a3 = []byte("</a>.\n")
)
func FastRedirect(rw http.ResponseWriter, req *http.Request, url string, code int) { func FastRedirect(rw http.ResponseWriter, req *http.Request, url string, code int) {
rw.Header().Add("Location", url) rw.Header().Add("Location", url)
rw.WriteHeader(code) rw.WriteHeader(code)

View File

@ -1,11 +1,15 @@
package utils package utils
// Compilable is an interface for structs with an asynchronous compile method.
type Compilable interface { type Compilable interface {
Compile() Compile()
} }
// MultiCompilable is a slice of multiple Compilable interfaces.
type MultiCompilable []Compilable type MultiCompilable []Compilable
// Compile loops over the slice of Compilable interfaces and calls Compile on
// each one.
func (m MultiCompilable) Compile() { func (m MultiCompilable) Compile() {
for _, i := range m { for _, i := range m {
i.Compile() i.Compile()

View File

@ -24,6 +24,8 @@ func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", "")) logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
} }
// GetBearer returns the bearer from the Authorization header or an empty string
// if the authorization is empty or doesn't start with Bearer.
func GetBearer(req *http.Request) string { func GetBearer(req *http.Request) string {
a := req.Header.Get("Authorization") a := req.Header.Get("Authorization")
if t, ok := strings.CutPrefix(a, "Bearer "); ok { if t, ok := strings.CutPrefix(a, "Bearer "); ok {