diff --git a/benchmarks/go.mod b/benchmarks/go.mod deleted file mode 100644 index e1c816f..0000000 --- a/benchmarks/go.mod +++ /dev/null @@ -1,12 +0,0 @@ -module benchmarks - -go 1.20 - -require ( - github.com/MrMelon54/violet v0.0.0-20230419182034-77d570ac1e6d - github.com/gorilla/mux v1.8.0 -) - -require github.com/MrMelon54/trie v0.0.2 // indirect - -replace github.com/MrMelon54/violet => ../ diff --git a/benchmarks/go.sum b/benchmarks/go.sum deleted file mode 100644 index e254c89..0000000 --- a/benchmarks/go.sum +++ /dev/null @@ -1,8 +0,0 @@ -github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY= -github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/benchmarks/router_test.go b/benchmarks/router_test.go deleted file mode 100644 index 61d68c8..0000000 --- a/benchmarks/router_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package benchmarks - -import ( - "github.com/MrMelon54/violet/router" - "github.com/MrMelon54/violet/target" - gorillaRouter "github.com/gorilla/mux" - "net/http" - "net/http/httptest" - "testing" -) - -func benchRequest(b *testing.B, router http.Handler, r *http.Request) { - w := httptest.NewRecorder() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - router.ServeHTTP(w, r) - } - if w.Header().Get("Location") != "https://example.com" { - b.Fatal("Location: ", w.Header().Get("Location"), " != https://example.com") - } -} - -func BenchmarkVioletRouter(b *testing.B) { - r := router.New(nil) - r.AddRedirect("*.example.com", "", target.Redirect{ - Pre: true, - Host: "example.com", - Code: http.StatusPermanentRedirect, - }) - benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)) -} - -func BenchmarkGorillaMux(b *testing.B) { - r := gorillaRouter.NewRouter() - r.Host("{subdomain}.example.com").Handler(target.Redirect{ - Pre: true, - Host: "example.com", - Code: http.StatusPermanentRedirect, - }) - benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com/", nil)) -} diff --git a/cmd/violet/main.go b/cmd/violet/main.go index e7bf603..69dc0f8 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -62,7 +62,7 @@ func main() { allowedDomains := domains.New(db) // load allowed domains allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager - reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy + reverseProxy := proxy.NewHybridTransport() // load reverse proxy dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager diff --git a/domains/domains.go b/domains/domains.go index 9bac6e6..343a057 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -48,12 +48,15 @@ func (d *Domains) IsValid(host string) bool { defer d.s.RUnlock() // check root domains `www.example.com`, `example.com`, `com` - // TODO: could be faster using indexes and cropping the string? - n := strings.Split(domain, ".") - for i := 0; i < len(n); i++ { - if _, ok := d.m[strings.Join(n[i:], ".")]; ok { + for len(domain) > 0 { + if _, ok := d.m[domain]; ok { return true } + n := strings.IndexByte(domain, '.') + if n == -1 { + break + } + domain = domain[n+1:] } return false } diff --git a/domains/domains_test.go b/domains/domains_test.go new file mode 100644 index 0000000..a4a68b9 --- /dev/null +++ b/domains/domains_test.go @@ -0,0 +1,45 @@ +package domains + +import ( + "database/sql" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDomainsNew(t *testing.T) { + db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + assert.NoError(t, err) + + domains := New(db) + _, err = db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1) + assert.NoError(t, err) + domains.Compile() + + if _, ok := domains.m["example.com"]; ok { + assert.True(t, ok) + } + + if _, ok := domains.m["www.example.com"]; !ok { + assert.False(t, ok) + } +} + +func TestDomains_IsValid(t *testing.T) { + // open sqlite database + db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + assert.NoError(t, err) + + domains := New(db) + _, err = domains.db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1) + assert.NoError(t, err) + + domains.s.Lock() + assert.NoError(t, domains.internalCompile(domains.m)) + domains.s.Unlock() + + assert.True(t, domains.IsValid("example.com")) + assert.True(t, domains.IsValid("www.example.com")) + assert.False(t, domains.IsValid("notexample.com")) + assert.False(t, domains.IsValid("www.notexample.com")) +} diff --git a/error-pages/error-pages.go b/error-pages/error-pages.go index 2534e0d..7eb91dc 100644 --- a/error-pages/error-pages.go +++ b/error-pages/error-pages.go @@ -29,6 +29,7 @@ func New(dir fs.FS) *ErrorPages { generic: func(rw http.ResponseWriter, code int) { // if status text is empty then the code is unknown a := http.StatusText(code) + fmt.Printf("%d - %s\n", code, a) if a != "" { // output in "xxx Error Text" format http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code) @@ -64,10 +65,12 @@ func (e *ErrorPages) Compile() { errorPageMap := make(map[int]func(rw http.ResponseWriter)) // compile map and check errors - err := e.internalCompile(errorPageMap) - if err != nil { - log.Printf("[Certs] Compile failed: %s\n", err) - return + if e.dir != nil { + err := e.internalCompile(errorPageMap) + if err != nil { + log.Printf("[Certs] Compile failed: %s\n", err) + return + } } // lock while replacing the map diff --git a/error-pages/error-pages_test.go b/error-pages/error-pages_test.go new file mode 100644 index 0000000..d5ee363 --- /dev/null +++ b/error-pages/error-pages_test.go @@ -0,0 +1,31 @@ +package error_pages + +import ( + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestErrorPages_ServeError(t *testing.T) { + errorPages := New(nil) + + rec := httptest.NewRecorder() + errorPages.ServeError(rec, http.StatusTeapot) + res := rec.Result() + assert.Equal(t, http.StatusTeapot, res.StatusCode) + assert.Equal(t, "418 I'm a teapot", res.Status) + a, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "418 I'm a teapot\n\n", string(a)) + + rec = httptest.NewRecorder() + errorPages.ServeError(rec, 469) + res = rec.Result() + assert.Equal(t, 469, res.StatusCode) + assert.Equal(t, "469 ", res.Status) + a, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "469 Unknown Error Code\n\n", string(a)) +} diff --git a/favicons/example.ico b/favicons/example.ico new file mode 100644 index 0000000..be2e640 Binary files /dev/null and b/favicons/example.ico differ diff --git a/favicons/example.png b/favicons/example.png new file mode 100644 index 0000000..e6bd204 Binary files /dev/null and b/favicons/example.png differ diff --git a/favicons/example.svg b/favicons/example.svg new file mode 100644 index 0000000..528b014 --- /dev/null +++ b/favicons/example.svg @@ -0,0 +1,4 @@ + + + + diff --git a/favicons/favicon-image.go b/favicons/favicon-image.go new file mode 100644 index 0000000..91a49bc --- /dev/null +++ b/favicons/favicon-image.go @@ -0,0 +1,17 @@ +package favicons + +// FaviconImage stores the url, hash and raw bytes of an image +type FaviconImage struct { + Url string + Hash string + Raw []byte +} + +// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if +// the URL is an empty string. +func CreateFaviconImage(url string) *FaviconImage { + if url == "" { + return nil + } + return &FaviconImage{Url: url} +} diff --git a/favicons/favicon-list.go b/favicons/favicon-list.go new file mode 100644 index 0000000..5ec0e95 --- /dev/null +++ b/favicons/favicon-list.go @@ -0,0 +1,144 @@ +package favicons + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/mrmelon54/png2ico" + "image/png" + "io" + "net/http" +) + +// FaviconList contains the ico, png and svg icons for separate favicons +type FaviconList struct { + Ico *FaviconImage // can be generated from png with wrapper + Png *FaviconImage // can be generated from svg with inkscape + Svg *FaviconImage +} + +// ProduceIco outputs the bytes of the ico icon or an error +func (l *FaviconList) ProduceIco() ([]byte, error) { + if l.Ico == nil { + return nil, ErrFaviconNotFound + } + return l.Ico.Raw, nil +} + +// ProducePng outputs the bytes of the png icon or an error +func (l *FaviconList) ProducePng() ([]byte, error) { + if l.Png == nil { + return nil, ErrFaviconNotFound + } + return l.Png.Raw, nil +} + +// ProduceSvg outputs the bytes of the svg icon or an error +func (l *FaviconList) ProduceSvg() ([]byte, error) { + if l.Svg == nil { + return nil, ErrFaviconNotFound + } + return l.Svg.Raw, nil +} + +// PreProcess takes an input of the svg2png conversion function and outputs +// an error if the SVG, PNG or ICO fails to download or generate +func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error { + var err error + + // SVG + if l.Svg != nil { + // download SVG + l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url) + if err != nil { + return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err) + } + l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw)) + } + + // PNG + if l.Png != nil { + // download PNG + l.Png.Raw, err = getFaviconViaRequest(l.Png.Url) + if err != nil { + return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err) + } + } else if l.Svg != nil { + // generate PNG from SVG + l.Png = &FaviconImage{} + l.Png.Raw, err = convert(l.Svg.Raw) + if err != nil { + return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err) + } + } + + // ICO + if l.Ico != nil { + // download ICO + l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url) + if err != nil { + return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err) + } + } else if l.Png != nil { + // generate ICO from PNG + l.Ico = &FaviconImage{} + decode, err := png.Decode(bytes.NewReader(l.Png.Raw)) + if err != nil { + return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err) + } + b := decode.Bounds() + l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy()) + if err != nil { + return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err) + } + } + + // generate sha256 hashes for svg, png and ico + l.genSha256() + return nil +} + +// genSha256 generates sha256 hashes +func (l *FaviconList) genSha256() { + if l.Svg != nil { + l.Svg.Hash = genSha256(l.Svg.Raw) + } + if l.Png != nil { + l.Png.Hash = genSha256(l.Png.Raw) + } + if l.Ico != nil { + l.Ico.Hash = genSha256(l.Ico.Raw) + } +} + +// getFaviconViaRequest uses the standard http request library to download +// icons, outputs the raw bytes from the download or an error. +var getFaviconViaRequest = func(url string) ([]byte, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err) + } + req.Header.Set("X-Violet-Raw-Favicon", "1") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err) + } + rawBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err) + } + return rawBody, nil +} + +// genSha256 generates a sha256 hash as a hex encoded string +func genSha256(in []byte) string { + // create sha256 generator and write to it + h := sha256.New() + _, err := h.Write(in) + if err != nil { + return "" + } + // encode as hex + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/favicons/favicon-list_test.go b/favicons/favicon-list_test.go new file mode 100644 index 0000000..9becf14 --- /dev/null +++ b/favicons/favicon-list_test.go @@ -0,0 +1,33 @@ +package favicons + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFaviconList_PreProcess(t *testing.T) { + getFaviconViaRequest = func(_ string) ([]byte, error) { + return exampleSvg, nil + } + icons := &FaviconList{Svg: &FaviconImage{Url: "https://example.com/assets/logo.svg"}} + assert.NoError(t, icons.PreProcess(func(in []byte) ([]byte, error) { + return svg2png("inkscape", in) + })) + iconSvg, err := icons.ProduceSvg() + assert.NoError(t, err) + iconPng, err := icons.ProducePng() + assert.NoError(t, err) + iconIco, err := icons.ProduceIco() + assert.NoError(t, err) + + assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url) + + assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash) + assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash) + assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash) + + assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg)) + assert.Equal(t, 0, bytes.Compare(examplePng, iconPng)) + assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco)) +} diff --git a/favicons/favicons.go b/favicons/favicons.go index 372c60f..e3682b8 100644 --- a/favicons/favicons.go +++ b/favicons/favicons.go @@ -1,18 +1,11 @@ package favicons import ( - "bytes" - "crypto/sha256" "database/sql" - "encoding/hex" "errors" "fmt" - "github.com/mrmelon54/png2ico" "golang.org/x/sync/errgroup" - "image/png" - "io" "log" - "net/http" "sync" ) @@ -121,150 +114,3 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error { func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) { return svg2png(f.cmd, in) } - -// FaviconList contains the ico, png and svg icons for separate favicons -type FaviconList struct { - Ico *FaviconImage // can be generated from png with wrapper - Png *FaviconImage // can be generated from svg with inkscape - Svg *FaviconImage -} - -// ProduceIco outputs the bytes of the ico icon or an error -func (l *FaviconList) ProduceIco() ([]byte, error) { - if l.Ico == nil { - return nil, ErrFaviconNotFound - } - return l.Ico.Raw, nil -} - -// ProducePng outputs the bytes of the png icon or an error -func (l *FaviconList) ProducePng() ([]byte, error) { - if l.Png == nil { - return nil, ErrFaviconNotFound - } - return l.Png.Raw, nil -} - -// ProduceSvg outputs the bytes of the svg icon or an error -func (l *FaviconList) ProduceSvg() ([]byte, error) { - if l.Svg == nil { - return nil, ErrFaviconNotFound - } - return l.Svg.Raw, nil -} - -// PreProcess takes an input of the svg2png conversion function and outputs -// an error if the SVG, PNG or ICO fails to download or generate -func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error { - var err error - - // SVG - if l.Svg != nil { - // download SVG - l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url) - if err != nil { - return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err) - } - l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw)) - } - - // PNG - if l.Png != nil { - // download PNG - l.Png.Raw, err = getFaviconViaRequest(l.Png.Url) - if err != nil { - return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err) - } - } else if l.Svg != nil { - // generate PNG from SVG - l.Png = &FaviconImage{} - l.Png.Raw, err = convert(l.Svg.Raw) - if err != nil { - return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err) - } - } - - // ICO - if l.Ico != nil { - // download ICO - l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url) - if err != nil { - return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err) - } - } else if l.Png != nil { - // generate ICO from PNG - l.Ico = &FaviconImage{} - decode, err := png.Decode(bytes.NewReader(l.Png.Raw)) - if err != nil { - return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err) - } - b := decode.Bounds() - l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy()) - if err != nil { - return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err) - } - } - - // generate sha256 hashes for svg, png and ico - l.genSha256() - return nil -} - -// genSha256 generates sha256 hashes -func (l *FaviconList) genSha256() { - if l.Svg != nil { - l.Svg.Hash = genSha256(l.Svg.Raw) - } - if l.Png != nil { - l.Png.Hash = genSha256(l.Png.Raw) - } - if l.Ico != nil { - l.Ico.Hash = genSha256(l.Ico.Raw) - } -} - -// getFaviconViaRequest uses the standard http request library to download -// icons, outputs the raw bytes from the download or an error. -func getFaviconViaRequest(url string) ([]byte, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err) - } - rawBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err) - } - return rawBody, nil -} - -// genSha256 generates a sha256 hash as a hex encoded string -func genSha256(in []byte) string { - // create sha256 generator and write to it - h := sha256.New() - _, err := h.Write(in) - if err != nil { - return "" - } - // encode as hex - return hex.EncodeToString(h.Sum(nil)) -} - -// FaviconImage stores the url, hash and raw bytes of an image -type FaviconImage struct { - Url string - Hash string - Raw []byte -} - -// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if -// the URL is an empty string. -func CreateFaviconImage(url string) *FaviconImage { - if url == "" { - return nil - } - return &FaviconImage{Url: url} -} diff --git a/favicons/favicons_test.go b/favicons/favicons_test.go new file mode 100644 index 0000000..9d85dd8 --- /dev/null +++ b/favicons/favicons_test.go @@ -0,0 +1,51 @@ +package favicons + +import ( + "bytes" + "database/sql" + _ "embed" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "testing" +) + +var ( + //go:embed example.svg + exampleSvg []byte + //go:embed example.png + examplePng []byte + //go:embed example.ico + exampleIco []byte +) + +func TestFaviconsNew(t *testing.T) { + getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil } + + db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + assert.NoError(t, err) + + favicons := New(db, "inkscape") + _, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "") + assert.NoError(t, err) + favicons.cLock.Lock() + assert.NoError(t, favicons.internalCompile(favicons.faviconMap)) + favicons.cLock.Unlock() + + icons := favicons.GetIcons("example.com") + iconSvg, err := icons.ProduceSvg() + assert.NoError(t, err) + iconPng, err := icons.ProducePng() + assert.NoError(t, err) + iconIco, err := icons.ProduceIco() + assert.NoError(t, err) + + assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url) + + assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash) + assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash) + assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash) + + assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg)) + assert.Equal(t, 0, bytes.Compare(examplePng, iconPng)) + assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco)) +} diff --git a/go.mod b/go.mod index 6126f5e..96917a5 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1 code.mrmelon54.com/melon/summer-utils v0.0.3 github.com/MrMelon54/trie v0.0.2 - github.com/gorilla/mux v1.8.0 github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.16 github.com/mrmelon54/mjwt v0.0.1 @@ -24,6 +23,7 @@ require ( github.com/kr/pretty v0.1.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/text v0.9.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0182b7d..b5468e7 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -43,6 +41,8 @@ golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/proxy/hybrid-transport.go b/proxy/hybrid-transport.go new file mode 100644 index 0000000..22235e0 --- /dev/null +++ b/proxy/hybrid-transport.go @@ -0,0 +1,73 @@ +package proxy + +import ( + "crypto/tls" + "net" + "net/http" + "sync" + "time" +) + +type HybridTransport struct { + baseDialer *net.Dialer + normalTransport http.RoundTripper + insecureTransport http.RoundTripper + socksSync *sync.RWMutex + socksTransport map[string]http.RoundTripper +} + +// NewHybridTransport creates a new hybrid transport +func NewHybridTransport() *HybridTransport { + return NewHybridTransportWithCalls(nil, nil) +} + +// NewHybridTransportWithCalls creates new hybrid transport with custom normal +// and insecure http.RoundTripper functions. +// +// NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport() +func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport { + h := &HybridTransport{ + baseDialer: &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }, + normalTransport: normal, + insecureTransport: insecure, + } + if h.normalTransport == nil { + h.normalTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: h.baseDialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 15, + TLSHandshakeTimeout: 10 * time.Second, + IdleConnTimeout: 30 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } + } + if h.insecureTransport == nil { + h.insecureTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: h.baseDialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 15, + TLSHandshakeTimeout: 10 * time.Second, + IdleConnTimeout: 30 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + } + return h +} + +// SecureRoundTrip calls the secure transport +func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, error) { + return h.normalTransport.RoundTrip(req) +} + +// InsecureRoundTrip calls the insecure transport +func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) { + return h.insecureTransport.RoundTrip(req) +} diff --git a/proxy/hybrid-transport_test.go b/proxy/hybrid-transport_test.go new file mode 100644 index 0000000..8c3b34f --- /dev/null +++ b/proxy/hybrid-transport_test.go @@ -0,0 +1,16 @@ +package proxy + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestNewHybridTransport(t *testing.T) { + h := NewHybridTransport() + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(t, err) + trip, err := h.SecureRoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, trip.StatusCode) +} diff --git a/proxy/reverse-proxy.go b/proxy/reverse-proxy.go deleted file mode 100644 index 46657d1..0000000 --- a/proxy/reverse-proxy.go +++ /dev/null @@ -1,145 +0,0 @@ -package proxy - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "golang.org/x/net/proxy" - "log" - "net" - "net/http" - "net/http/httputil" - "sync" - "time" -) - -type reverseProxyHostKey int - -type ReverseProxyContext interface { - IsIgnoreCert() bool - UpdateHeaders(http.Header) -} - -func SetReverseProxyHost(req *http.Request, hf ReverseProxyContext) *http.Request { - ctx := req.Context() - ctx2 := context.WithValue(ctx, reverseProxyHostKey(0), hf) - return req.WithContext(ctx2) -} - -func CreateHybridReverseProxy() *httputil.ReverseProxy { - return &httputil.ReverseProxy{ - Director: func(req *http.Request) {}, - Transport: NewHybridTransport(), - ModifyResponse: func(rw *http.Response) error { return nil }, - ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err) - rw.WriteHeader(http.StatusBadGateway) - _, _ = rw.Write([]byte("502 Bad gateway\n")) - }, - } -} - -type HybridTransport struct { - baseDialer *net.Dialer - normalTransport http.RoundTripper - insecureTransport http.RoundTripper - socksSync *sync.RWMutex - socksTransport map[string]http.RoundTripper -} - -func NewHybridTransport() *HybridTransport { - h := &HybridTransport{ - baseDialer: &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }, - socksSync: &sync.RWMutex{}, - socksTransport: make(map[string]http.RoundTripper), - } - h.normalTransport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: h.baseDialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 15, - TLSHandshakeTimeout: 10 * time.Second, - IdleConnTimeout: 30 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - } - h.insecureTransport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: h.baseDialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 15, - TLSHandshakeTimeout: 10 * time.Second, - IdleConnTimeout: 30 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - return h -} - -func (h *HybridTransport) RoundTrip(req *http.Request) (*http.Response, error) { - newHost := req.Context().Value(reverseProxyHostKey(0)) - hf, ok := newHost.(ReverseProxyContext) - if !ok { - return nil, errors.New("failed to detect reverse proxy configuration") - } - - // Do a round trip using existing transports - var trip *http.Response - var err error - if hf.IsIgnoreCert() { - trip, err = h.insecureTransport.RoundTrip(req) - } else { - trip, err = h.normalTransport.RoundTrip(req) - } - if err != nil { - return nil, err - } - - // Override headers - hf.UpdateHeaders(trip.Header) - return trip, nil -} - -func (h *HybridTransport) getSocksProxy(addr string, insecure bool) (http.RoundTripper, error) { - if insecure { - addr = "%i-" + addr - } - h.socksSync.RLock() - s, ok := h.socksTransport[addr] - h.socksSync.RUnlock() - if ok { - return s, nil - } - - dialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("cannot connect to the proxy: %s", err) - } - - if f, ok := dialer.(proxy.ContextDialer); ok { - t := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: f.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 15, - TLSHandshakeTimeout: 10 * time.Second, - IdleConnTimeout: 30 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - DisableKeepAlives: true, - } - if insecure { - t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - h.socksSync.Lock() - h.socksTransport[addr] = t - h.socksSync.Unlock() - return t, nil - } - return nil, errors.New("cannot create socks5 dialer") -} diff --git a/proxy/reverse-proxy_test.go b/proxy/reverse-proxy_test.go new file mode 100644 index 0000000..ddf43a3 --- /dev/null +++ b/proxy/reverse-proxy_test.go @@ -0,0 +1,48 @@ +package proxy + +import ( + "log" + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" +) + +type customTransport struct{} + +func (c *customTransport) RoundTrip(_ *http.Request) (*http.Response, error) { + res := httptest.NewRecorder() + res.WriteHeader(http.StatusOK) + res.Write([]byte{0x54, 0x54}) + return res.Result(), nil +} + +func SetupReverseProxy() *httputil.ReverseProxy { + return &httputil.ReverseProxy{ + Director: func(req *http.Request) {}, + Transport: &customTransport{}, + ModifyResponse: func(rw *http.Response) error { return nil }, + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err) + rw.WriteHeader(http.StatusBadGateway) + _, _ = rw.Write([]byte("502 Bad gateway\n")) + }, + } +} + +func BenchmarkHttpUtilReverseProxy(b *testing.B) { + rev := SetupReverseProxy() + req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + rev.ServeHTTP(rec, req) + } +} + +func BenchmarkCustomTransport(b *testing.B) { + req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) + t := &customTransport{} + for i := 0; i < b.N; i++ { + _, _ = t.RoundTrip(req) + } +} diff --git a/router/manager.go b/router/manager.go index 0d8f7b8..dd884d1 100644 --- a/router/manager.go +++ b/router/manager.go @@ -4,11 +4,11 @@ import ( "database/sql" _ "embed" "fmt" + "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/target" "github.com/MrMelon54/violet/utils" "log" "net/http" - "net/http/httputil" "strings" "sync" ) @@ -19,7 +19,7 @@ type Manager struct { db *sql.DB s *sync.RWMutex r *Router - p *httputil.ReverseProxy + p *proxy.HybridTransport } var ( @@ -35,7 +35,7 @@ var ( // NewManager create a new manager, initialises the routes and redirects tables // in the database and runs a first time compile. -func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager { +func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager { m := &Manager{ db: db, s: &sync.RWMutex{}, diff --git a/router/router.go b/router/router.go index b6e89c0..53f21da 100644 --- a/router/router.go +++ b/router/router.go @@ -3,10 +3,10 @@ package router import ( "fmt" "github.com/MrMelon54/trie" + "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/target" "github.com/MrMelon54/violet/utils" "net/http" - "net/http/httputil" "strings" ) @@ -14,10 +14,10 @@ type Router struct { route map[string]*trie.Trie[target.Route] redirect map[string]*trie.Trie[target.Redirect] notFound http.Handler - proxy *httputil.ReverseProxy + proxy *proxy.HybridTransport } -func New(proxy *httputil.ReverseProxy) *Router { +func New(proxy *proxy.HybridTransport) *Router { return &Router{ route: make(map[string]*trie.Trie[target.Route]), redirect: make(map[string]*trie.Trie[target.Redirect]), diff --git a/servers/http.go b/servers/http.go index 7f04479..416ff7f 100644 --- a/servers/http.go +++ b/servers/http.go @@ -14,7 +14,7 @@ import ( // endpoints for the reverse proxy. // // `/.well-known/acme-challenge/{token}` is used for outputting answers for -// acme challenges, this is used for Lets Encrypt HTTP verification. +// acme challenges, this is used for Let's Encrypt HTTP verification. func NewHttpServer(conf *Conf) *http.Server { r := httprouter.New() var secureExtend string diff --git a/servers/https.go b/servers/https.go index 31b6afc..185a661 100644 --- a/servers/https.go +++ b/servers/https.go @@ -3,8 +3,8 @@ package servers import ( "crypto/tls" "fmt" + "github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/utils" - "github.com/gorilla/mux" "github.com/sethvargo/go-limiter/httplimit" "github.com/sethvargo/go-limiter/memorystore" "log" @@ -18,7 +18,7 @@ import ( func NewHttpsServer(conf *Conf) *http.Server { s := &http.Server{ Addr: conf.HttpsListen, - Handler: setupRateLimiter(300).Middleware(conf.Router), + Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)), DisableGeneralOptionsHandler: false, TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { // error out on invalid domains @@ -51,7 +51,7 @@ func NewHttpsServer(conf *Conf) *http.Server { // setupRateLimiter is an internal function to create a middleware to manage // rate limits. -func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc { +func setupRateLimiter(rateLimit uint64, next http.Handler) http.Handler { // create memory store store, err := memorystore.New(&memorystore.Config{ Tokens: rateLimit, @@ -66,5 +66,45 @@ func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc { if err != nil { log.Fatalln(err) } - return middleware.Handle + return middleware.Handle(next) +} + +func setupFaviconMiddleware(fav *favicons.Favicons, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("X-Violet-Raw-Favicon") != "1" { + switch req.URL.Path { + case "/favicon.svg": + icons := fav.GetIcons(req.Host) + raw, err := icons.ProduceSvg() + if err != nil { + utils.RespondVioletError(rw, http.StatusTeapot, "No SVG icon available") + return + } + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(raw) + return + case "/favicon.png": + icons := fav.GetIcons(req.Host) + raw, err := icons.ProducePng() + if err != nil { + utils.RespondVioletError(rw, http.StatusTeapot, "No PNG icon available") + return + } + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(raw) + return + case "/favicon.ico": + icons := fav.GetIcons(req.Host) + raw, err := icons.ProduceIco() + if err != nil { + utils.RespondVioletError(rw, http.StatusTeapot, "No ICO icon available") + return + } + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(raw) + return + } + } + next.ServeHTTP(rw, req) + }) } diff --git a/target/route.go b/target/route.go index 1c18948..e074c23 100644 --- a/target/route.go +++ b/target/route.go @@ -1,12 +1,17 @@ package target import ( + "context" "fmt" "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/utils" "github.com/rs/cors" + "golang.org/x/net/http/httpguts" + "io" "log" + "net" "net/http" + "net/textproto" "net/url" "path" "strings" @@ -31,23 +36,20 @@ var serveApiCors = cors.New(cors.Options{ // Route is a target used by the router to manage forwarding traffic to an // internal server using the specified configuration. type Route struct { - Pre bool // if the path has had a prefix removed - Host string // target host - Port int // target port - Path string // target path (possibly a prefix or absolute) - Abs bool // if the path is a prefix or absolute - Cors bool // add CORS headers - SecureMode bool // use HTTPS internally - ForwardHost bool // forward host header internally - ForwardAddr bool // forward remote address - IgnoreCert bool // ignore self-cert - Headers http.Header // extra headers - Proxy http.Handler // reverse proxy handler + Pre bool // if the path has had a prefix removed + Host string // target host + Port int // target port + Path string // target path (possibly a prefix or absolute) + Abs bool // if the path is a prefix or absolute + Cors bool // add CORS headers + SecureMode bool // use HTTPS internally + ForwardHost bool // forward host header internally + ForwardAddr bool // forward remote address + IgnoreCert bool // ignore self-cert + Headers http.Header // extra headers + Proxy *proxy.HybridTransport // reverse proxy handler } -// IsIgnoreCert returns true if IgnoreCert is enabled. -func (r Route) IsIgnoreCert() bool { return r.IgnoreCert } - // UpdateHeaders takes an existing set of headers and overwrites them with the // extra headers. func (r Route) UpdateHeaders(header http.Header) { @@ -122,7 +124,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { req2, err := http.NewRequest(req.Method, u.String(), req.Body) if err != nil { log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err) - utils.RespondHttpStatus(rw, http.StatusBadGateway) + utils.RespondVioletError(rw, http.StatusBadGateway, "error generating new request") return } @@ -153,11 +155,162 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { req2.Header.Add("X-Forwarded-For", req.RemoteAddr) } + // adds extra request metadata + r.internalReverseProxyMeta(rw, req) + // serve request with reverse proxy - r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r)) + var resp *http.Response + if r.IgnoreCert { + resp, err = r.Proxy.InsecureRoundTrip(req2) + } else { + resp, err = r.Proxy.SecureRoundTrip(req2) + } + + // copy headers and status code + copyHeader(rw.Header(), resp.Header) + rw.WriteHeader(resp.StatusCode) + + // copy body + if resp.Body != nil { + _, err := io.Copy(rw, resp.Body) + if err != nil { + // hijack and close upon error + if h, ok := rw.(http.Hijacker); ok { + hijack, _, err := h.Hijack() + if err != nil { + return + } + _ = hijack.Close() + } + return + } + } +} + +// internalReverseProxyMeta is mainly built from code copied from httputil.ReverseProxy, +// due to the highly custom nature of this reverse proxy software we use a copy +// of the code instead of the full httputil implementation to prevent overhead +// from the more generic implementation +func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) { + outreq := req.Clone(context.Background()) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Body != nil { + // Reading from the request body after returning from a handler is not + // allowed, and the RoundTrip goroutine that reads the Body can outlive + // this handler. This can lead to a crash if the handler panics (see + // Issue 46866). Although calling Close doesn't guarantee there isn't + // any Read in flight after the handle returns, in practice it's safe to + // read after closing it. + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + reqUpType := upgradeType(outreq.Header) + if !asciiIsPrint(reqUpType) { + utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType)) + return + } + removeHopByHopHeaders(outreq.Header) + + // Issue 21096: tell backend applications that care about trailer support + // that we support trailers. (We do, but we don't go out of our way to + // advertise that unless the incoming client request thought it was worth + // mentioning.) Note that we look at req.Header, not outreq.Header, since + // the latter has passed through removeHopByHopHeaders. + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := outreq.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + if !omit { + outreq.Header.Set("X-Forwarded-For", clientIP) + } + } } // String outputs a debug string for the route. func (r Route) String() string { return fmt.Sprintf("%#v", r) } + +// copyHeader copies all headers from src to dst +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// updateType returns the value of upgrade from http.Header +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return h.Get("Upgrade") +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func asciiIsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +// removeHopByHopHeaders removes the hop-by-hop headers defined in hopHeaders +func removeHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} diff --git a/target/route_test.go b/target/route_test.go index a1598c9..075e73c 100644 --- a/target/route_test.go +++ b/target/route_test.go @@ -1,6 +1,7 @@ package target import ( + "github.com/MrMelon54/violet/proxy" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -9,14 +10,17 @@ import ( type proxyTester struct { got bool - rw http.ResponseWriter req *http.Request } -func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport { + return proxy.NewHybridTransportWithCalls(p, p) +} + +func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) { p.got = true - p.rw = rw p.req = req + return &http.Response{StatusCode: http.StatusOK}, nil } func TestRoute_FullHost(t *testing.T) { @@ -38,7 +42,7 @@ func TestRoute_ServeHTTP(t *testing.T) { } for _, i := range a { pt := &proxyTester{} - i.Proxy = pt + i.Proxy = pt.makeHybridTransport() res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil) i.ServeHTTP(res, req) @@ -62,7 +66,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) { res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil) req.Header.Set("Origin", "https://test.example.com") - i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt} + i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()} i.ServeHTTP(res, req) assert.True(t, pt.got) diff --git a/utils/response.go b/utils/response.go index cf59f42..4a54287 100644 --- a/utils/response.go +++ b/utils/response.go @@ -9,3 +9,8 @@ import ( func RespondHttpStatus(rw http.ResponseWriter, status int) { http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status) } + +func RespondVioletError(rw http.ResponseWriter, status int, msg string) { + rw.Header().Set("X-Violet-Error", msg) + RespondHttpStatus(rw, status) +}