mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 10:51:40 +00:00
Yes rewrote some stuff
This commit is contained in:
parent
9899d67d50
commit
1f487eb80c
@ -1,12 +0,0 @@
|
|||||||
module benchmarks
|
|
||||||
|
|
||||||
go 1.20
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/MrMelon54/violet v0.0.0-20230419182034-77d570ac1e6d
|
|
||||||
github.com/gorilla/mux v1.8.0
|
|
||||||
)
|
|
||||||
|
|
||||||
require github.com/MrMelon54/trie v0.0.2 // indirect
|
|
||||||
|
|
||||||
replace github.com/MrMelon54/violet => ../
|
|
@ -1,8 +0,0 @@
|
|||||||
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
|
|
||||||
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
|
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
|
||||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
|
||||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
|
||||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|
@ -1,42 +0,0 @@
|
|||||||
package benchmarks
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/MrMelon54/violet/router"
|
|
||||||
"github.com/MrMelon54/violet/target"
|
|
||||||
gorillaRouter "github.com/gorilla/mux"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func benchRequest(b *testing.B, router http.Handler, r *http.Request) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
router.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
if w.Header().Get("Location") != "https://example.com" {
|
|
||||||
b.Fatal("Location: ", w.Header().Get("Location"), " != https://example.com")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkVioletRouter(b *testing.B) {
|
|
||||||
r := router.New(nil)
|
|
||||||
r.AddRedirect("*.example.com", "", target.Redirect{
|
|
||||||
Pre: true,
|
|
||||||
Host: "example.com",
|
|
||||||
Code: http.StatusPermanentRedirect,
|
|
||||||
})
|
|
||||||
benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGorillaMux(b *testing.B) {
|
|
||||||
r := gorillaRouter.NewRouter()
|
|
||||||
r.Host("{subdomain}.example.com").Handler(target.Redirect{
|
|
||||||
Pre: true,
|
|
||||||
Host: "example.com",
|
|
||||||
Code: http.StatusPermanentRedirect,
|
|
||||||
})
|
|
||||||
benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com/", nil))
|
|
||||||
}
|
|
@ -62,7 +62,7 @@ func main() {
|
|||||||
|
|
||||||
allowedDomains := domains.New(db) // load allowed domains
|
allowedDomains := domains.New(db) // load allowed domains
|
||||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
|
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
|
||||||
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
|
reverseProxy := proxy.NewHybridTransport() // load reverse proxy
|
||||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||||
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
||||||
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager
|
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager
|
||||||
|
@ -48,12 +48,15 @@ func (d *Domains) IsValid(host string) bool {
|
|||||||
defer d.s.RUnlock()
|
defer d.s.RUnlock()
|
||||||
|
|
||||||
// check root domains `www.example.com`, `example.com`, `com`
|
// check root domains `www.example.com`, `example.com`, `com`
|
||||||
// TODO: could be faster using indexes and cropping the string?
|
for len(domain) > 0 {
|
||||||
n := strings.Split(domain, ".")
|
if _, ok := d.m[domain]; ok {
|
||||||
for i := 0; i < len(n); i++ {
|
|
||||||
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
n := strings.IndexByte(domain, '.')
|
||||||
|
if n == -1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
domain = domain[n+1:]
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
45
domains/domains_test.go
Normal file
45
domains/domains_test.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package domains
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDomainsNew(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
domains := New(db)
|
||||||
|
_, err = db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
domains.Compile()
|
||||||
|
|
||||||
|
if _, ok := domains.m["example.com"]; ok {
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := domains.m["www.example.com"]; !ok {
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomains_IsValid(t *testing.T) {
|
||||||
|
// open sqlite database
|
||||||
|
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
domains := New(db)
|
||||||
|
_, err = domains.db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
domains.s.Lock()
|
||||||
|
assert.NoError(t, domains.internalCompile(domains.m))
|
||||||
|
domains.s.Unlock()
|
||||||
|
|
||||||
|
assert.True(t, domains.IsValid("example.com"))
|
||||||
|
assert.True(t, domains.IsValid("www.example.com"))
|
||||||
|
assert.False(t, domains.IsValid("notexample.com"))
|
||||||
|
assert.False(t, domains.IsValid("www.notexample.com"))
|
||||||
|
}
|
@ -29,6 +29,7 @@ func New(dir fs.FS) *ErrorPages {
|
|||||||
generic: func(rw http.ResponseWriter, code int) {
|
generic: func(rw http.ResponseWriter, code int) {
|
||||||
// if status text is empty then the code is unknown
|
// if status text is empty then the code is unknown
|
||||||
a := http.StatusText(code)
|
a := http.StatusText(code)
|
||||||
|
fmt.Printf("%d - %s\n", code, a)
|
||||||
if a != "" {
|
if a != "" {
|
||||||
// output in "xxx Error Text" format
|
// output in "xxx Error Text" format
|
||||||
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
|
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
|
||||||
@ -64,10 +65,12 @@ func (e *ErrorPages) Compile() {
|
|||||||
errorPageMap := make(map[int]func(rw http.ResponseWriter))
|
errorPageMap := make(map[int]func(rw http.ResponseWriter))
|
||||||
|
|
||||||
// compile map and check errors
|
// compile map and check errors
|
||||||
err := e.internalCompile(errorPageMap)
|
if e.dir != nil {
|
||||||
if err != nil {
|
err := e.internalCompile(errorPageMap)
|
||||||
log.Printf("[Certs] Compile failed: %s\n", err)
|
if err != nil {
|
||||||
return
|
log.Printf("[Certs] Compile failed: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock while replacing the map
|
// lock while replacing the map
|
||||||
|
31
error-pages/error-pages_test.go
Normal file
31
error-pages/error-pages_test.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package error_pages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrorPages_ServeError(t *testing.T) {
|
||||||
|
errorPages := New(nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
errorPages.ServeError(rec, http.StatusTeapot)
|
||||||
|
res := rec.Result()
|
||||||
|
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||||
|
assert.Equal(t, "418 I'm a teapot", res.Status)
|
||||||
|
a, err := io.ReadAll(res.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
|
||||||
|
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
errorPages.ServeError(rec, 469)
|
||||||
|
res = rec.Result()
|
||||||
|
assert.Equal(t, 469, res.StatusCode)
|
||||||
|
assert.Equal(t, "469 ", res.Status)
|
||||||
|
a, err = io.ReadAll(res.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "469 Unknown Error Code\n\n", string(a))
|
||||||
|
}
|
BIN
favicons/example.ico
Normal file
BIN
favicons/example.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
BIN
favicons/example.png
Normal file
BIN
favicons/example.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
4
favicons/example.svg
Normal file
4
favicons/example.svg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg height="100" width="100" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<circle cx="50" cy="50" r="40" stroke="black" stroke-width="3" fill="red"/>
|
||||||
|
</svg>
|
After Width: | Height: | Size: 214 B |
17
favicons/favicon-image.go
Normal file
17
favicons/favicon-image.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package favicons
|
||||||
|
|
||||||
|
// FaviconImage stores the url, hash and raw bytes of an image
|
||||||
|
type FaviconImage struct {
|
||||||
|
Url string
|
||||||
|
Hash string
|
||||||
|
Raw []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
|
||||||
|
// the URL is an empty string.
|
||||||
|
func CreateFaviconImage(url string) *FaviconImage {
|
||||||
|
if url == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &FaviconImage{Url: url}
|
||||||
|
}
|
144
favicons/favicon-list.go
Normal file
144
favicons/favicon-list.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package favicons
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"github.com/mrmelon54/png2ico"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FaviconList contains the ico, png and svg icons for separate favicons
|
||||||
|
type FaviconList struct {
|
||||||
|
Ico *FaviconImage // can be generated from png with wrapper
|
||||||
|
Png *FaviconImage // can be generated from svg with inkscape
|
||||||
|
Svg *FaviconImage
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProduceIco outputs the bytes of the ico icon or an error
|
||||||
|
func (l *FaviconList) ProduceIco() ([]byte, error) {
|
||||||
|
if l.Ico == nil {
|
||||||
|
return nil, ErrFaviconNotFound
|
||||||
|
}
|
||||||
|
return l.Ico.Raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProducePng outputs the bytes of the png icon or an error
|
||||||
|
func (l *FaviconList) ProducePng() ([]byte, error) {
|
||||||
|
if l.Png == nil {
|
||||||
|
return nil, ErrFaviconNotFound
|
||||||
|
}
|
||||||
|
return l.Png.Raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProduceSvg outputs the bytes of the svg icon or an error
|
||||||
|
func (l *FaviconList) ProduceSvg() ([]byte, error) {
|
||||||
|
if l.Svg == nil {
|
||||||
|
return nil, ErrFaviconNotFound
|
||||||
|
}
|
||||||
|
return l.Svg.Raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreProcess takes an input of the svg2png conversion function and outputs
|
||||||
|
// an error if the SVG, PNG or ICO fails to download or generate
|
||||||
|
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// SVG
|
||||||
|
if l.Svg != nil {
|
||||||
|
// download SVG
|
||||||
|
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
|
||||||
|
}
|
||||||
|
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PNG
|
||||||
|
if l.Png != nil {
|
||||||
|
// download PNG
|
||||||
|
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
|
||||||
|
}
|
||||||
|
} else if l.Svg != nil {
|
||||||
|
// generate PNG from SVG
|
||||||
|
l.Png = &FaviconImage{}
|
||||||
|
l.Png.Raw, err = convert(l.Svg.Raw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICO
|
||||||
|
if l.Ico != nil {
|
||||||
|
// download ICO
|
||||||
|
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
|
||||||
|
}
|
||||||
|
} else if l.Png != nil {
|
||||||
|
// generate ICO from PNG
|
||||||
|
l.Ico = &FaviconImage{}
|
||||||
|
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
|
||||||
|
}
|
||||||
|
b := decode.Bounds()
|
||||||
|
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate sha256 hashes for svg, png and ico
|
||||||
|
l.genSha256()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// genSha256 generates sha256 hashes
|
||||||
|
func (l *FaviconList) genSha256() {
|
||||||
|
if l.Svg != nil {
|
||||||
|
l.Svg.Hash = genSha256(l.Svg.Raw)
|
||||||
|
}
|
||||||
|
if l.Png != nil {
|
||||||
|
l.Png.Hash = genSha256(l.Png.Raw)
|
||||||
|
}
|
||||||
|
if l.Ico != nil {
|
||||||
|
l.Ico.Hash = genSha256(l.Ico.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFaviconViaRequest uses the standard http request library to download
|
||||||
|
// icons, outputs the raw bytes from the download or an error.
|
||||||
|
var getFaviconViaRequest = func(url string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Violet-Raw-Favicon", "1")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
|
||||||
|
}
|
||||||
|
rawBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
|
||||||
|
}
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// genSha256 generates a sha256 hash as a hex encoded string
|
||||||
|
func genSha256(in []byte) string {
|
||||||
|
// create sha256 generator and write to it
|
||||||
|
h := sha256.New()
|
||||||
|
_, err := h.Write(in)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// encode as hex
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
33
favicons/favicon-list_test.go
Normal file
33
favicons/favicon-list_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package favicons
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFaviconList_PreProcess(t *testing.T) {
|
||||||
|
getFaviconViaRequest = func(_ string) ([]byte, error) {
|
||||||
|
return exampleSvg, nil
|
||||||
|
}
|
||||||
|
icons := &FaviconList{Svg: &FaviconImage{Url: "https://example.com/assets/logo.svg"}}
|
||||||
|
assert.NoError(t, icons.PreProcess(func(in []byte) ([]byte, error) {
|
||||||
|
return svg2png("inkscape", in)
|
||||||
|
}))
|
||||||
|
iconSvg, err := icons.ProduceSvg()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
iconPng, err := icons.ProducePng()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
iconIco, err := icons.ProduceIco()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url)
|
||||||
|
|
||||||
|
assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash)
|
||||||
|
assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash)
|
||||||
|
assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg))
|
||||||
|
assert.Equal(t, 0, bytes.Compare(examplePng, iconPng))
|
||||||
|
assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco))
|
||||||
|
}
|
@ -1,18 +1,11 @@
|
|||||||
package favicons
|
package favicons
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/sha256"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/mrmelon54/png2ico"
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"image/png"
|
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,150 +114,3 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
|
|||||||
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
|
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
|
||||||
return svg2png(f.cmd, in)
|
return svg2png(f.cmd, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FaviconList contains the ico, png and svg icons for separate favicons
|
|
||||||
type FaviconList struct {
|
|
||||||
Ico *FaviconImage // can be generated from png with wrapper
|
|
||||||
Png *FaviconImage // can be generated from svg with inkscape
|
|
||||||
Svg *FaviconImage
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProduceIco outputs the bytes of the ico icon or an error
|
|
||||||
func (l *FaviconList) ProduceIco() ([]byte, error) {
|
|
||||||
if l.Ico == nil {
|
|
||||||
return nil, ErrFaviconNotFound
|
|
||||||
}
|
|
||||||
return l.Ico.Raw, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProducePng outputs the bytes of the png icon or an error
|
|
||||||
func (l *FaviconList) ProducePng() ([]byte, error) {
|
|
||||||
if l.Png == nil {
|
|
||||||
return nil, ErrFaviconNotFound
|
|
||||||
}
|
|
||||||
return l.Png.Raw, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProduceSvg outputs the bytes of the svg icon or an error
|
|
||||||
func (l *FaviconList) ProduceSvg() ([]byte, error) {
|
|
||||||
if l.Svg == nil {
|
|
||||||
return nil, ErrFaviconNotFound
|
|
||||||
}
|
|
||||||
return l.Svg.Raw, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreProcess takes an input of the svg2png conversion function and outputs
|
|
||||||
// an error if the SVG, PNG or ICO fails to download or generate
|
|
||||||
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// SVG
|
|
||||||
if l.Svg != nil {
|
|
||||||
// download SVG
|
|
||||||
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
|
|
||||||
}
|
|
||||||
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
|
|
||||||
}
|
|
||||||
|
|
||||||
// PNG
|
|
||||||
if l.Png != nil {
|
|
||||||
// download PNG
|
|
||||||
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
|
|
||||||
}
|
|
||||||
} else if l.Svg != nil {
|
|
||||||
// generate PNG from SVG
|
|
||||||
l.Png = &FaviconImage{}
|
|
||||||
l.Png.Raw, err = convert(l.Svg.Raw)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ICO
|
|
||||||
if l.Ico != nil {
|
|
||||||
// download ICO
|
|
||||||
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
|
|
||||||
}
|
|
||||||
} else if l.Png != nil {
|
|
||||||
// generate ICO from PNG
|
|
||||||
l.Ico = &FaviconImage{}
|
|
||||||
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
|
|
||||||
}
|
|
||||||
b := decode.Bounds()
|
|
||||||
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate sha256 hashes for svg, png and ico
|
|
||||||
l.genSha256()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genSha256 generates sha256 hashes
|
|
||||||
func (l *FaviconList) genSha256() {
|
|
||||||
if l.Svg != nil {
|
|
||||||
l.Svg.Hash = genSha256(l.Svg.Raw)
|
|
||||||
}
|
|
||||||
if l.Png != nil {
|
|
||||||
l.Png.Hash = genSha256(l.Png.Raw)
|
|
||||||
}
|
|
||||||
if l.Ico != nil {
|
|
||||||
l.Ico.Hash = genSha256(l.Ico.Raw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFaviconViaRequest uses the standard http request library to download
|
|
||||||
// icons, outputs the raw bytes from the download or an error.
|
|
||||||
func getFaviconViaRequest(url string) ([]byte, error) {
|
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
|
|
||||||
}
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
|
|
||||||
}
|
|
||||||
rawBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
|
|
||||||
}
|
|
||||||
return rawBody, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genSha256 generates a sha256 hash as a hex encoded string
|
|
||||||
func genSha256(in []byte) string {
|
|
||||||
// create sha256 generator and write to it
|
|
||||||
h := sha256.New()
|
|
||||||
_, err := h.Write(in)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// encode as hex
|
|
||||||
return hex.EncodeToString(h.Sum(nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
// FaviconImage stores the url, hash and raw bytes of an image
|
|
||||||
type FaviconImage struct {
|
|
||||||
Url string
|
|
||||||
Hash string
|
|
||||||
Raw []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
|
|
||||||
// the URL is an empty string.
|
|
||||||
func CreateFaviconImage(url string) *FaviconImage {
|
|
||||||
if url == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &FaviconImage{Url: url}
|
|
||||||
}
|
|
||||||
|
51
favicons/favicons_test.go
Normal file
51
favicons/favicons_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package favicons
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed example.svg
|
||||||
|
exampleSvg []byte
|
||||||
|
//go:embed example.png
|
||||||
|
examplePng []byte
|
||||||
|
//go:embed example.ico
|
||||||
|
exampleIco []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFaviconsNew(t *testing.T) {
|
||||||
|
getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil }
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
favicons := New(db, "inkscape")
|
||||||
|
_, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
favicons.cLock.Lock()
|
||||||
|
assert.NoError(t, favicons.internalCompile(favicons.faviconMap))
|
||||||
|
favicons.cLock.Unlock()
|
||||||
|
|
||||||
|
icons := favicons.GetIcons("example.com")
|
||||||
|
iconSvg, err := icons.ProduceSvg()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
iconPng, err := icons.ProducePng()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
iconIco, err := icons.ProduceIco()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url)
|
||||||
|
|
||||||
|
assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash)
|
||||||
|
assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash)
|
||||||
|
assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg))
|
||||||
|
assert.Equal(t, 0, bytes.Compare(examplePng, iconPng))
|
||||||
|
assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco))
|
||||||
|
}
|
2
go.mod
2
go.mod
@ -6,7 +6,6 @@ require (
|
|||||||
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1
|
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1
|
||||||
code.mrmelon54.com/melon/summer-utils v0.0.3
|
code.mrmelon54.com/melon/summer-utils v0.0.3
|
||||||
github.com/MrMelon54/trie v0.0.2
|
github.com/MrMelon54/trie v0.0.2
|
||||||
github.com/gorilla/mux v1.8.0
|
|
||||||
github.com/julienschmidt/httprouter v1.3.0
|
github.com/julienschmidt/httprouter v1.3.0
|
||||||
github.com/mattn/go-sqlite3 v1.14.16
|
github.com/mattn/go-sqlite3 v1.14.16
|
||||||
github.com/mrmelon54/mjwt v0.0.1
|
github.com/mrmelon54/mjwt v0.0.1
|
||||||
@ -24,6 +23,7 @@ require (
|
|||||||
github.com/kr/pretty v0.1.0 // indirect
|
github.com/kr/pretty v0.1.0 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
golang.org/x/text v0.9.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
4
go.sum
4
go.sum
@ -9,8 +9,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
|
||||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
|
||||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||||
@ -43,6 +41,8 @@ golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
|
|||||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||||
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
73
proxy/hybrid-transport.go
Normal file
73
proxy/hybrid-transport.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HybridTransport struct {
|
||||||
|
baseDialer *net.Dialer
|
||||||
|
normalTransport http.RoundTripper
|
||||||
|
insecureTransport http.RoundTripper
|
||||||
|
socksSync *sync.RWMutex
|
||||||
|
socksTransport map[string]http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHybridTransport creates a new hybrid transport
|
||||||
|
func NewHybridTransport() *HybridTransport {
|
||||||
|
return NewHybridTransportWithCalls(nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHybridTransportWithCalls creates new hybrid transport with custom normal
|
||||||
|
// and insecure http.RoundTripper functions.
|
||||||
|
//
|
||||||
|
// NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport()
|
||||||
|
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport {
|
||||||
|
h := &HybridTransport{
|
||||||
|
baseDialer: &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
},
|
||||||
|
normalTransport: normal,
|
||||||
|
insecureTransport: insecure,
|
||||||
|
}
|
||||||
|
if h.normalTransport == nil {
|
||||||
|
h.normalTransport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: h.baseDialer.DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 15,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
IdleConnTimeout: 30 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.insecureTransport == nil {
|
||||||
|
h.insecureTransport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: h.baseDialer.DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 15,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
IdleConnTimeout: 30 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecureRoundTrip calls the secure transport
|
||||||
|
func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return h.normalTransport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsecureRoundTrip calls the insecure transport
|
||||||
|
func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return h.insecureTransport.RoundTrip(req)
|
||||||
|
}
|
16
proxy/hybrid-transport_test.go
Normal file
16
proxy/hybrid-transport_test.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewHybridTransport(t *testing.T) {
|
||||||
|
h := NewHybridTransport()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
trip, err := h.SecureRoundTrip(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, trip.StatusCode)
|
||||||
|
}
|
@ -1,145 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type reverseProxyHostKey int
|
|
||||||
|
|
||||||
type ReverseProxyContext interface {
|
|
||||||
IsIgnoreCert() bool
|
|
||||||
UpdateHeaders(http.Header)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetReverseProxyHost(req *http.Request, hf ReverseProxyContext) *http.Request {
|
|
||||||
ctx := req.Context()
|
|
||||||
ctx2 := context.WithValue(ctx, reverseProxyHostKey(0), hf)
|
|
||||||
return req.WithContext(ctx2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateHybridReverseProxy() *httputil.ReverseProxy {
|
|
||||||
return &httputil.ReverseProxy{
|
|
||||||
Director: func(req *http.Request) {},
|
|
||||||
Transport: NewHybridTransport(),
|
|
||||||
ModifyResponse: func(rw *http.Response) error { return nil },
|
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
|
||||||
log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err)
|
|
||||||
rw.WriteHeader(http.StatusBadGateway)
|
|
||||||
_, _ = rw.Write([]byte("502 Bad gateway\n"))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type HybridTransport struct {
|
|
||||||
baseDialer *net.Dialer
|
|
||||||
normalTransport http.RoundTripper
|
|
||||||
insecureTransport http.RoundTripper
|
|
||||||
socksSync *sync.RWMutex
|
|
||||||
socksTransport map[string]http.RoundTripper
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHybridTransport() *HybridTransport {
|
|
||||||
h := &HybridTransport{
|
|
||||||
baseDialer: &net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
},
|
|
||||||
socksSync: &sync.RWMutex{},
|
|
||||||
socksTransport: make(map[string]http.RoundTripper),
|
|
||||||
}
|
|
||||||
h.normalTransport = &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: h.baseDialer.DialContext,
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 15,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
IdleConnTimeout: 30 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
ResponseHeaderTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
h.insecureTransport = &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: h.baseDialer.DialContext,
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 15,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
IdleConnTimeout: 30 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
ResponseHeaderTimeout: 10 * time.Second,
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
}
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HybridTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
newHost := req.Context().Value(reverseProxyHostKey(0))
|
|
||||||
hf, ok := newHost.(ReverseProxyContext)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("failed to detect reverse proxy configuration")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do a round trip using existing transports
|
|
||||||
var trip *http.Response
|
|
||||||
var err error
|
|
||||||
if hf.IsIgnoreCert() {
|
|
||||||
trip, err = h.insecureTransport.RoundTrip(req)
|
|
||||||
} else {
|
|
||||||
trip, err = h.normalTransport.RoundTrip(req)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override headers
|
|
||||||
hf.UpdateHeaders(trip.Header)
|
|
||||||
return trip, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HybridTransport) getSocksProxy(addr string, insecure bool) (http.RoundTripper, error) {
|
|
||||||
if insecure {
|
|
||||||
addr = "%i-" + addr
|
|
||||||
}
|
|
||||||
h.socksSync.RLock()
|
|
||||||
s, ok := h.socksTransport[addr]
|
|
||||||
h.socksSync.RUnlock()
|
|
||||||
if ok {
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot connect to the proxy: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if f, ok := dialer.(proxy.ContextDialer); ok {
|
|
||||||
t := &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: f.DialContext,
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 15,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
IdleConnTimeout: 30 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
ResponseHeaderTimeout: 10 * time.Second,
|
|
||||||
DisableKeepAlives: true,
|
|
||||||
}
|
|
||||||
if insecure {
|
|
||||||
t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
}
|
|
||||||
h.socksSync.Lock()
|
|
||||||
h.socksTransport[addr] = t
|
|
||||||
h.socksSync.Unlock()
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
return nil, errors.New("cannot create socks5 dialer")
|
|
||||||
}
|
|
48
proxy/reverse-proxy_test.go
Normal file
48
proxy/reverse-proxy_test.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httputil"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type customTransport struct{}
|
||||||
|
|
||||||
|
func (c *customTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
res.WriteHeader(http.StatusOK)
|
||||||
|
res.Write([]byte{0x54, 0x54})
|
||||||
|
return res.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupReverseProxy() *httputil.ReverseProxy {
|
||||||
|
return &httputil.ReverseProxy{
|
||||||
|
Director: func(req *http.Request) {},
|
||||||
|
Transport: &customTransport{},
|
||||||
|
ModifyResponse: func(rw *http.Response) error { return nil },
|
||||||
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err)
|
||||||
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
_, _ = rw.Write([]byte("502 Bad gateway\n"))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHttpUtilReverseProxy(b *testing.B) {
|
||||||
|
rev := SetupReverseProxy()
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rev.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCustomTransport(b *testing.B) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
t := &customTransport{}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = t.RoundTrip(req)
|
||||||
|
}
|
||||||
|
}
|
@ -4,11 +4,11 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/MrMelon54/violet/target"
|
"github.com/MrMelon54/violet/target"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
@ -19,7 +19,7 @@ type Manager struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
s *sync.RWMutex
|
s *sync.RWMutex
|
||||||
r *Router
|
r *Router
|
||||||
p *httputil.ReverseProxy
|
p *proxy.HybridTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -35,7 +35,7 @@ var (
|
|||||||
|
|
||||||
// NewManager create a new manager, initialises the routes and redirects tables
|
// NewManager create a new manager, initialises the routes and redirects tables
|
||||||
// in the database and runs a first time compile.
|
// in the database and runs a first time compile.
|
||||||
func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager {
|
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
db: db,
|
db: db,
|
||||||
s: &sync.RWMutex{},
|
s: &sync.RWMutex{},
|
||||||
|
@ -3,10 +3,10 @@ package router
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/trie"
|
"github.com/MrMelon54/trie"
|
||||||
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/MrMelon54/violet/target"
|
"github.com/MrMelon54/violet/target"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,10 +14,10 @@ type Router struct {
|
|||||||
route map[string]*trie.Trie[target.Route]
|
route map[string]*trie.Trie[target.Route]
|
||||||
redirect map[string]*trie.Trie[target.Redirect]
|
redirect map[string]*trie.Trie[target.Redirect]
|
||||||
notFound http.Handler
|
notFound http.Handler
|
||||||
proxy *httputil.ReverseProxy
|
proxy *proxy.HybridTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(proxy *httputil.ReverseProxy) *Router {
|
func New(proxy *proxy.HybridTransport) *Router {
|
||||||
return &Router{
|
return &Router{
|
||||||
route: make(map[string]*trie.Trie[target.Route]),
|
route: make(map[string]*trie.Trie[target.Route]),
|
||||||
redirect: make(map[string]*trie.Trie[target.Redirect]),
|
redirect: make(map[string]*trie.Trie[target.Redirect]),
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
// endpoints for the reverse proxy.
|
// endpoints for the reverse proxy.
|
||||||
//
|
//
|
||||||
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||||
// acme challenges, this is used for Lets Encrypt HTTP verification.
|
// acme challenges, this is used for Let's Encrypt HTTP verification.
|
||||||
func NewHttpServer(conf *Conf) *http.Server {
|
func NewHttpServer(conf *Conf) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
var secureExtend string
|
var secureExtend string
|
||||||
|
@ -3,8 +3,8 @@ package servers
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/MrMelon54/violet/favicons"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/gorilla/mux"
|
|
||||||
"github.com/sethvargo/go-limiter/httplimit"
|
"github.com/sethvargo/go-limiter/httplimit"
|
||||||
"github.com/sethvargo/go-limiter/memorystore"
|
"github.com/sethvargo/go-limiter/memorystore"
|
||||||
"log"
|
"log"
|
||||||
@ -18,7 +18,7 @@ import (
|
|||||||
func NewHttpsServer(conf *Conf) *http.Server {
|
func NewHttpsServer(conf *Conf) *http.Server {
|
||||||
s := &http.Server{
|
s := &http.Server{
|
||||||
Addr: conf.HttpsListen,
|
Addr: conf.HttpsListen,
|
||||||
Handler: setupRateLimiter(300).Middleware(conf.Router),
|
Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
||||||
DisableGeneralOptionsHandler: false,
|
DisableGeneralOptionsHandler: false,
|
||||||
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
// error out on invalid domains
|
// error out on invalid domains
|
||||||
@ -51,7 +51,7 @@ func NewHttpsServer(conf *Conf) *http.Server {
|
|||||||
|
|
||||||
// setupRateLimiter is an internal function to create a middleware to manage
|
// setupRateLimiter is an internal function to create a middleware to manage
|
||||||
// rate limits.
|
// rate limits.
|
||||||
func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
|
func setupRateLimiter(rateLimit uint64, next http.Handler) http.Handler {
|
||||||
// create memory store
|
// create memory store
|
||||||
store, err := memorystore.New(&memorystore.Config{
|
store, err := memorystore.New(&memorystore.Config{
|
||||||
Tokens: rateLimit,
|
Tokens: rateLimit,
|
||||||
@ -66,5 +66,45 @@ func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
return middleware.Handle
|
return middleware.Handle(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupFaviconMiddleware(fav *favicons.Favicons, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Header.Get("X-Violet-Raw-Favicon") != "1" {
|
||||||
|
switch req.URL.Path {
|
||||||
|
case "/favicon.svg":
|
||||||
|
icons := fav.GetIcons(req.Host)
|
||||||
|
raw, err := icons.ProduceSvg()
|
||||||
|
if err != nil {
|
||||||
|
utils.RespondVioletError(rw, http.StatusTeapot, "No SVG icon available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rw.Write(raw)
|
||||||
|
return
|
||||||
|
case "/favicon.png":
|
||||||
|
icons := fav.GetIcons(req.Host)
|
||||||
|
raw, err := icons.ProducePng()
|
||||||
|
if err != nil {
|
||||||
|
utils.RespondVioletError(rw, http.StatusTeapot, "No PNG icon available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rw.Write(raw)
|
||||||
|
return
|
||||||
|
case "/favicon.ico":
|
||||||
|
icons := fav.GetIcons(req.Host)
|
||||||
|
raw, err := icons.ProduceIco()
|
||||||
|
if err != nil {
|
||||||
|
utils.RespondVioletError(rw, http.StatusTeapot, "No ICO icon available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = rw.Write(raw)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
187
target/route.go
187
target/route.go
@ -1,12 +1,17 @@
|
|||||||
package target
|
package target
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/violet/proxy"
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
|
"golang.org/x/net/http/httpguts"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
@ -31,23 +36,20 @@ var serveApiCors = cors.New(cors.Options{
|
|||||||
// Route is a target used by the router to manage forwarding traffic to an
|
// Route is a target used by the router to manage forwarding traffic to an
|
||||||
// internal server using the specified configuration.
|
// internal server using the specified configuration.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Pre bool // if the path has had a prefix removed
|
Pre bool // if the path has had a prefix removed
|
||||||
Host string // target host
|
Host string // target host
|
||||||
Port int // target port
|
Port int // target port
|
||||||
Path string // target path (possibly a prefix or absolute)
|
Path string // target path (possibly a prefix or absolute)
|
||||||
Abs bool // if the path is a prefix or absolute
|
Abs bool // if the path is a prefix or absolute
|
||||||
Cors bool // add CORS headers
|
Cors bool // add CORS headers
|
||||||
SecureMode bool // use HTTPS internally
|
SecureMode bool // use HTTPS internally
|
||||||
ForwardHost bool // forward host header internally
|
ForwardHost bool // forward host header internally
|
||||||
ForwardAddr bool // forward remote address
|
ForwardAddr bool // forward remote address
|
||||||
IgnoreCert bool // ignore self-cert
|
IgnoreCert bool // ignore self-cert
|
||||||
Headers http.Header // extra headers
|
Headers http.Header // extra headers
|
||||||
Proxy http.Handler // reverse proxy handler
|
Proxy *proxy.HybridTransport // reverse proxy handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsIgnoreCert returns true if IgnoreCert is enabled.
|
|
||||||
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
|
|
||||||
|
|
||||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||||
// extra headers.
|
// extra headers.
|
||||||
func (r Route) UpdateHeaders(header http.Header) {
|
func (r Route) UpdateHeaders(header http.Header) {
|
||||||
@ -122,7 +124,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
req2, err := http.NewRequest(req.Method, u.String(), req.Body)
|
req2, err := http.NewRequest(req.Method, u.String(), req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
|
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
|
||||||
utils.RespondHttpStatus(rw, http.StatusBadGateway)
|
utils.RespondVioletError(rw, http.StatusBadGateway, "error generating new request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,11 +155,162 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// adds extra request metadata
|
||||||
|
r.internalReverseProxyMeta(rw, req)
|
||||||
|
|
||||||
// serve request with reverse proxy
|
// serve request with reverse proxy
|
||||||
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
|
var resp *http.Response
|
||||||
|
if r.IgnoreCert {
|
||||||
|
resp, err = r.Proxy.InsecureRoundTrip(req2)
|
||||||
|
} else {
|
||||||
|
resp, err = r.Proxy.SecureRoundTrip(req2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy headers and status code
|
||||||
|
copyHeader(rw.Header(), resp.Header)
|
||||||
|
rw.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
// copy body
|
||||||
|
if resp.Body != nil {
|
||||||
|
_, err := io.Copy(rw, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
// hijack and close upon error
|
||||||
|
if h, ok := rw.(http.Hijacker); ok {
|
||||||
|
hijack, _, err := h.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = hijack.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// internalReverseProxyMeta is mainly built from code copied from httputil.ReverseProxy,
|
||||||
|
// due to the highly custom nature of this reverse proxy software we use a copy
|
||||||
|
// of the code instead of the full httputil implementation to prevent overhead
|
||||||
|
// from the more generic implementation
|
||||||
|
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
outreq := req.Clone(context.Background())
|
||||||
|
if req.ContentLength == 0 {
|
||||||
|
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||||
|
}
|
||||||
|
if outreq.Body != nil {
|
||||||
|
// Reading from the request body after returning from a handler is not
|
||||||
|
// allowed, and the RoundTrip goroutine that reads the Body can outlive
|
||||||
|
// this handler. This can lead to a crash if the handler panics (see
|
||||||
|
// Issue 46866). Although calling Close doesn't guarantee there isn't
|
||||||
|
// any Read in flight after the handle returns, in practice it's safe to
|
||||||
|
// read after closing it.
|
||||||
|
defer outreq.Body.Close()
|
||||||
|
}
|
||||||
|
if outreq.Header == nil {
|
||||||
|
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||||
|
}
|
||||||
|
|
||||||
|
reqUpType := upgradeType(outreq.Header)
|
||||||
|
if !asciiIsPrint(reqUpType) {
|
||||||
|
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
removeHopByHopHeaders(outreq.Header)
|
||||||
|
|
||||||
|
// Issue 21096: tell backend applications that care about trailer support
|
||||||
|
// that we support trailers. (We do, but we don't go out of our way to
|
||||||
|
// advertise that unless the incoming client request thought it was worth
|
||||||
|
// mentioning.) Note that we look at req.Header, not outreq.Header, since
|
||||||
|
// the latter has passed through removeHopByHopHeaders.
|
||||||
|
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||||
|
outreq.Header.Set("Te", "trailers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// After stripping all the hop-by-hop connection headers above, add back any
|
||||||
|
// necessary for protocol upgrades, such as for websockets.
|
||||||
|
if reqUpType != "" {
|
||||||
|
outreq.Header.Set("Connection", "Upgrade")
|
||||||
|
outreq.Header.Set("Upgrade", reqUpType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||||
|
// If we aren't the first proxy retain prior
|
||||||
|
// X-Forwarded-For information as a comma+space
|
||||||
|
// separated list and fold multiple headers into one.
|
||||||
|
prior, ok := outreq.Header["X-Forwarded-For"]
|
||||||
|
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
|
||||||
|
if len(prior) > 0 {
|
||||||
|
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||||
|
}
|
||||||
|
if !omit {
|
||||||
|
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// String outputs a debug string for the route.
|
// String outputs a debug string for the route.
|
||||||
func (r Route) String() string {
|
func (r Route) String() string {
|
||||||
return fmt.Sprintf("%#v", r)
|
return fmt.Sprintf("%#v", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copyHeader copies all headers from src to dst
|
||||||
|
func copyHeader(dst, src http.Header) {
|
||||||
|
for k, vv := range src {
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateType returns the value of upgrade from http.Header
|
||||||
|
func upgradeType(h http.Header) string {
|
||||||
|
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return h.Get("Upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrint returns whether s is ASCII and printable according to
|
||||||
|
// https://tools.ietf.org/html/rfc20#section-4.2.
|
||||||
|
func asciiIsPrint(s string) bool {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] < ' ' || s[i] > '~' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
|
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||||
|
// Connection header field. These are the headers defined by the
|
||||||
|
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||||
|
// compatibility.
|
||||||
|
var hopHeaders = []string{
|
||||||
|
"Connection",
|
||||||
|
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||||
|
"Keep-Alive",
|
||||||
|
"Proxy-Authenticate",
|
||||||
|
"Proxy-Authorization",
|
||||||
|
"Te", // canonicalized version of "TE"
|
||||||
|
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||||
|
"Transfer-Encoding",
|
||||||
|
"Upgrade",
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeHopByHopHeaders removes the hop-by-hop headers defined in hopHeaders
|
||||||
|
func removeHopByHopHeaders(h http.Header) {
|
||||||
|
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||||
|
for _, f := range h["Connection"] {
|
||||||
|
for _, sf := range strings.Split(f, ",") {
|
||||||
|
if sf = textproto.TrimString(sf); sf != "" {
|
||||||
|
h.Del(sf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||||
|
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||||
|
// preserve it for backwards compatibility.
|
||||||
|
for _, f := range hopHeaders {
|
||||||
|
h.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package target
|
package target
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -9,14 +10,17 @@ import (
|
|||||||
|
|
||||||
type proxyTester struct {
|
type proxyTester struct {
|
||||||
got bool
|
got bool
|
||||||
rw http.ResponseWriter
|
|
||||||
req *http.Request
|
req *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport {
|
||||||
|
return proxy.NewHybridTransportWithCalls(p, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
p.got = true
|
p.got = true
|
||||||
p.rw = rw
|
|
||||||
p.req = req
|
p.req = req
|
||||||
|
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoute_FullHost(t *testing.T) {
|
func TestRoute_FullHost(t *testing.T) {
|
||||||
@ -38,7 +42,7 @@ func TestRoute_ServeHTTP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, i := range a {
|
for _, i := range a {
|
||||||
pt := &proxyTester{}
|
pt := &proxyTester{}
|
||||||
i.Proxy = pt
|
i.Proxy = pt.makeHybridTransport()
|
||||||
res := httptest.NewRecorder()
|
res := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
|
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
|
||||||
i.ServeHTTP(res, req)
|
i.ServeHTTP(res, req)
|
||||||
@ -62,7 +66,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
|
|||||||
res := httptest.NewRecorder()
|
res := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
||||||
req.Header.Set("Origin", "https://test.example.com")
|
req.Header.Set("Origin", "https://test.example.com")
|
||||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt}
|
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
|
||||||
i.ServeHTTP(res, req)
|
i.ServeHTTP(res, req)
|
||||||
|
|
||||||
assert.True(t, pt.got)
|
assert.True(t, pt.got)
|
||||||
|
@ -9,3 +9,8 @@ import (
|
|||||||
func RespondHttpStatus(rw http.ResponseWriter, status int) {
|
func RespondHttpStatus(rw http.ResponseWriter, status int) {
|
||||||
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
|
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RespondVioletError(rw http.ResponseWriter, status int, msg string) {
|
||||||
|
rw.Header().Set("X-Violet-Error", msg)
|
||||||
|
RespondHttpStatus(rw, status)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user