diff --git a/favicons/favicon-list.go b/favicons/favicon-list.go index ab275b7..76f81f9 100644 --- a/favicons/favicon-list.go +++ b/favicons/favicon-list.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/sha256" "encoding/hex" + "errors" "fmt" "github.com/MrMelon54/png2ico" "image/png" @@ -18,6 +19,27 @@ type FaviconList struct { Svg *FaviconImage } +var ErrInvalidFaviconExtension = errors.New("invalid favicon extension") + +// ProduceForExt outputs the bytes for the ico/png/svg icon and the HTTP +// Content-Type header to output. +func (l *FaviconList) ProduceForExt(ext string) (raw []byte, contentType string, err error) { + switch ext { + case ".ico": + contentType = "image/x-icon" + raw, err = l.ProduceIco() + case ".png": + contentType = "image/png" + raw, err = l.ProducePng() + case ".svg": + contentType = "image/svg+xml" + raw, err = l.ProduceSvg() + default: + err = ErrInvalidFaviconExtension + } + return +} + // ProduceIco outputs the bytes of the ico icon or an error func (l *FaviconList) ProduceIco() ([]byte, error) { if l.Ico == nil { diff --git a/servers/api_test.go b/servers/api_test.go index 58dd10c..4120a48 100644 --- a/servers/api_test.go +++ b/servers/api_test.go @@ -18,9 +18,10 @@ var snakeOilProv = genSnakeOilProv() type fakeDomains struct{} -func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" } -func (f *fakeDomains) Put(domain string, active bool) {} -func (f *fakeDomains) Delete(domain string) {} +func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" } +func (f *fakeDomains) Put(string, bool) {} +func (f *fakeDomains) Delete(string) {} +func (f *fakeDomains) Compile() {} func genSnakeOilProv() mjwt.Signer { key, err := rsa.GenerateKey(rand.Reader, 1024) diff --git a/servers/https.go b/servers/https.go index b85a1e7..9718cf6 100644 --- a/servers/https.go +++ b/servers/https.go @@ -10,6 +10,7 @@ import ( "log" "net" "net/http" + "path" "time" ) @@ -69,33 +70,14 @@ func setupFaviconMiddleware(fav *favicons.Favicons, next http.Handler) http.Hand return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { if req.Header.Get("X-Violet-Raw-Favicon") != "1" { switch req.URL.Path { - case "/favicon.svg": + case "/favicon.svg", "/favicon.png", "/favicon.ico": icons := fav.GetIcons(req.Host) - raw, err := icons.ProduceSvg() + raw, contentType, err := icons.ProduceForExt(path.Ext(req.URL.Path)) if err != nil { - utils.RespondVioletError(rw, http.StatusTeapot, "No SVG icon available") - return - } - rw.WriteHeader(http.StatusOK) - _, _ = rw.Write(raw) - return - case "/favicon.png": - icons := fav.GetIcons(req.Host) - raw, err := icons.ProducePng() - if err != nil { - utils.RespondVioletError(rw, http.StatusTeapot, "No PNG icon available") - return - } - rw.WriteHeader(http.StatusOK) - _, _ = rw.Write(raw) - return - case "/favicon.ico": - icons := fav.GetIcons(req.Host) - raw, err := icons.ProduceIco() - if err != nil { - utils.RespondVioletError(rw, http.StatusTeapot, "No ICO icon available") + utils.RespondVioletError(rw, http.StatusTeapot, "No icon available") return } + rw.Header().Set("Content-Type", contentType) rw.WriteHeader(http.StatusOK) _, _ = rw.Write(raw) return diff --git a/utils/response_test.go b/utils/response_test.go index 2f33781..058ddb9 100644 --- a/utils/response_test.go +++ b/utils/response_test.go @@ -16,7 +16,7 @@ func TestRespondHttpStatus(t *testing.T) { assert.Equal(t, "418 I'm a teapot", res.Status) a, err := io.ReadAll(res.Body) assert.NoError(t, err) - assert.Equal(t, "418 I'm a teapot\n\n", string(a)) + assert.Equal(t, "418 I'm a teapot\n", string(a)) } func TestRespondVioletError(t *testing.T) { @@ -27,6 +27,6 @@ func TestRespondVioletError(t *testing.T) { assert.Equal(t, "418 I'm a teapot", res.Status) a, err := io.ReadAll(res.Body) assert.NoError(t, err) - assert.Equal(t, "418 I'm a teapot\n\n", string(a)) + assert.Equal(t, "418 I'm a teapot\n", string(a)) assert.Equal(t, "Hidden Error Message", res.Header.Get("X-Violet-Error")) }