violet/error-pages/error-pages.go

150 lines
3.6 KiB
Go

package error_pages
import (
"fmt"
"github.com/1f349/violet/logger"
"github.com/mrmelon54/rescheduler"
"io/fs"
"net/http"
"path/filepath"
"strconv"
"strings"
"sync"
)
// Logger is the logger used by this package. It is derived from
// logger.Logger with the prefix "Violet Error Pages".
var Logger = logger.Logger.WithPrefix("Violet Error Pages")
// ErrorPages stores the custom error pages and is called by the servers to
// output meaningful pages for HTTP error codes
type ErrorPages struct {
// s guards m: ServeError takes the read lock and threadCompile takes the
// write lock when it replaces the map
s *sync.RWMutex
// m maps an HTTP status code to the function that writes its custom page
m map[int]func(rw http.ResponseWriter)
// generic writes the fallback page for codes that have no entry in m
generic func(rw http.ResponseWriter, code int)
// dir is the filesystem the custom pages are read from; when it is nil
// threadCompile installs an empty map without reading anything
dir fs.FS
// r is the rescheduler created around threadCompile; Compile calls its Run
// method
r *rescheduler.Rescheduler
}
// New creates a new error pages generator.
//
// dir is the filesystem that custom "<code>.html" pages are loaded from when
// Compile is called; it may be nil, in which case only the generic page is
// ever served. No pages are loaded by New itself.
func New(dir fs.FS) *ErrorPages {
	e := &ErrorPages{
		s: &sync.RWMutex{},
		m: make(map[int]func(rw http.ResponseWriter)),
		// generic error page writer
		generic: func(rw http.ResponseWriter, code int) {
			// if status text is empty then the code is unknown
			a := http.StatusText(code)
			if a != "" {
				// output in "xxx Error Text" format
				http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
				return
			}
			// output the code and generic unknown message
			http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code)
		},
		dir: dir,
	}
	// the rescheduler needs the method value, so it is attached after e exists
	e.r = rescheduler.NewRescheduler(e.threadCompile)
	return e
}
// ServeError writes the error page for the given code to the response writer.
//
// The custom page registered for code is used when one exists; otherwise the
// generic page is written.
func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
	// Only the map lookup needs the read lock. Releasing it before the
	// response is written means a slow client cannot delay threadCompile's
	// write lock, which would in turn block every later reader.
	e.s.RLock()
	page, ok := e.m[code]
	e.s.RUnlock()

	// use the custom error page if it exists
	if ok {
		page(rw)
		return
	}

	// otherwise use the generic error page
	e.generic(rw, code)
}
// Compile requests a reload of the custom error pages.
//
// It only calls Run on the rescheduler, which New created around
// threadCompile. This method makes use of the rescheduler instead of just
// ignoring multiple calls.
//
// NOTE(review): whether Run returns before the reload has finished depends on
// the rescheduler implementation - confirm before relying on either behaviour.
func (e *ErrorPages) Compile() {
e.r.Run()
}
// threadCompile builds a fresh lookup table and installs it in place of the
// current one. If loading the pages fails the failure is logged and the
// existing table is left untouched.
func (e *ErrorPages) threadCompile() {
	pages := make(map[int]func(rw http.ResponseWriter))

	// without a directory there is nothing to load, so the empty table is
	// installed as-is
	if e.dir != nil {
		if err := e.internalCompile(pages); err != nil {
			Logger.Info("Compile failed", "err", err)
			return
		}
	}

	// swap the table under the write lock
	e.s.Lock()
	defer e.s.Unlock()
	e.m = pages
}
// internalCompile reads the top level of e.dir and stores a page writer in m
// for every file named "<code>.html" whose code is in the range 100-599.
//
// Directories are skipped silently. Files without the ".html" extension, names
// that are not an integer, and codes outside the range are skipped with a
// warning. A failure to list the directory or to read an accepted file stops
// the compile and is returned wrapped; m may already contain entries for files
// processed before the failure.
func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error {
	// try to read dir
	files, err := fs.ReadDir(e.dir, ".")
	if err != nil {
		return fmt.Errorf("failed to read error pages dir: %w", err)
	}

	// the count covers every directory entry, including ones skipped below
	Logger.Info("Compiling lookup table", "page count", len(files))

	// find and load error pages
	for _, entry := range files {
		// skip dirs
		if entry.IsDir() {
			continue
		}

		// get file name and extension
		name := entry.Name()
		ext := filepath.Ext(name)

		// if the extension is not 'html' then ignore the file
		if ext != ".html" {
			Logger.Warn("Ignoring non '.html' file in error pages directory", "name", name)
			continue
		}

		// if the name can't be parsed as an integer status code then ignore
		// the file
		nameInt, err := strconv.Atoi(strings.TrimSuffix(name, ".html"))
		if err != nil {
			Logger.Warn("Ignoring invalid error page in error pages directory", "name", name)
			continue
		}

		// check if code is in range 100-599
		if nameInt < 100 || nameInt >= 600 {
			Logger.Warn("Ignoring invalid error page in error pages directory must be 100-599", "name", name)
			continue
		}

		// try to read html file
		htmlData, err := fs.ReadFile(e.dir, name)
		if err != nil {
			return fmt.Errorf("failed to read html file '%s': %w", name, err)
		}

		// create a callback function to write the page; nameInt and htmlData
		// are declared inside the loop body so each closure captures its own
		m[nameInt] = func(rw http.ResponseWriter) {
			// the media type parameter for the character set is "charset";
			// "encoding" is not a recognised parameter
			rw.Header().Set("Content-Type", "text/html; charset=utf-8")
			rw.WriteHeader(nameInt)
			// the status line has already been sent, so a failed body write
			// cannot be reported to the client and is deliberately ignored
			_, _ = rw.Write(htmlData)
		}
	}

	// well no errors happened
	return nil
}