Yes rewrote some stuff

This commit is contained in:
Melon 2023-06-03 19:33:06 +01:00
parent 9899d67d50
commit 1f487eb80c
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
29 changed files with 715 additions and 406 deletions

View File

@ -1,12 +0,0 @@
module benchmarks
go 1.20
require (
github.com/MrMelon54/violet v0.0.0-20230419182034-77d570ac1e6d
github.com/gorilla/mux v1.8.0
)
require github.com/MrMelon54/trie v0.0.2 // indirect
replace github.com/MrMelon54/violet => ../

View File

@ -1,8 +0,0 @@
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -1,42 +0,0 @@
package benchmarks
import (
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/target"
gorillaRouter "github.com/gorilla/mux"
"net/http"
"net/http/httptest"
"testing"
)
func benchRequest(b *testing.B, router http.Handler, r *http.Request) {
w := httptest.NewRecorder()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.ServeHTTP(w, r)
}
if w.Header().Get("Location") != "https://example.com" {
b.Fatal("Location: ", w.Header().Get("Location"), " != https://example.com")
}
}
func BenchmarkVioletRouter(b *testing.B) {
r := router.New(nil)
r.AddRedirect("*.example.com", "", target.Redirect{
Pre: true,
Host: "example.com",
Code: http.StatusPermanentRedirect,
})
benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil))
}
func BenchmarkGorillaMux(b *testing.B) {
r := gorillaRouter.NewRouter()
r.Host("{subdomain}.example.com").Handler(target.Redirect{
Pre: true,
Host: "example.com",
Code: http.StatusPermanentRedirect,
})
benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com/", nil))
}

View File

@ -62,7 +62,7 @@ func main() {
allowedDomains := domains.New(db) // load allowed domains
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
reverseProxy := proxy.NewHybridTransport() // load reverse proxy
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager

View File

@ -48,12 +48,15 @@ func (d *Domains) IsValid(host string) bool {
defer d.s.RUnlock()
// check root domains `www.example.com`, `example.com`, `com`
// TODO: could be faster using indexes and cropping the string?
n := strings.Split(domain, ".")
for i := 0; i < len(n); i++ {
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
for len(domain) > 0 {
if _, ok := d.m[domain]; ok {
return true
}
n := strings.IndexByte(domain, '.')
if n == -1 {
break
}
domain = domain[n+1:]
}
return false
}

45
domains/domains_test.go Normal file
View File

@ -0,0 +1,45 @@
package domains
import (
"database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDomainsNew(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
domains := New(db)
_, err = db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
assert.NoError(t, err)
domains.Compile()
if _, ok := domains.m["example.com"]; ok {
assert.True(t, ok)
}
if _, ok := domains.m["www.example.com"]; !ok {
assert.False(t, ok)
}
}
func TestDomains_IsValid(t *testing.T) {
// open sqlite database
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
domains := New(db)
_, err = domains.db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
assert.NoError(t, err)
domains.s.Lock()
assert.NoError(t, domains.internalCompile(domains.m))
domains.s.Unlock()
assert.True(t, domains.IsValid("example.com"))
assert.True(t, domains.IsValid("www.example.com"))
assert.False(t, domains.IsValid("notexample.com"))
assert.False(t, domains.IsValid("www.notexample.com"))
}

View File

@ -29,6 +29,7 @@ func New(dir fs.FS) *ErrorPages {
generic: func(rw http.ResponseWriter, code int) {
// if status text is empty then the code is unknown
a := http.StatusText(code)
fmt.Printf("%d - %s\n", code, a)
if a != "" {
// output in "xxx Error Text" format
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
@ -64,11 +65,13 @@ func (e *ErrorPages) Compile() {
errorPageMap := make(map[int]func(rw http.ResponseWriter))
// compile map and check errors
if e.dir != nil {
err := e.internalCompile(errorPageMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
}
// lock while replacing the map
e.s.Lock()

View File

@ -0,0 +1,31 @@
package error_pages
import (
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestErrorPages_ServeError(t *testing.T) {
errorPages := New(nil)
rec := httptest.NewRecorder()
errorPages.ServeError(rec, http.StatusTeapot)
res := rec.Result()
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Equal(t, "418 I'm a teapot", res.Status)
a, err := io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
rec = httptest.NewRecorder()
errorPages.ServeError(rec, 469)
res = rec.Result()
assert.Equal(t, 469, res.StatusCode)
assert.Equal(t, "469 ", res.Status)
a, err = io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "469 Unknown Error Code\n\n", string(a))
}

BIN
favicons/example.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

BIN
favicons/example.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

4
favicons/example.svg Normal file
View File

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg height="100" width="100" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
<circle cx="50" cy="50" r="40" stroke="black" stroke-width="3" fill="red"/>
</svg>

After

Width:  |  Height:  |  Size: 214 B

17
favicons/favicon-image.go Normal file
View File

@ -0,0 +1,17 @@
package favicons
// FaviconImage stores the url, hash and raw bytes of an image
type FaviconImage struct {
Url string
Hash string
Raw []byte
}
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
// the URL is an empty string.
func CreateFaviconImage(url string) *FaviconImage {
if url == "" {
return nil
}
return &FaviconImage{Url: url}
}

144
favicons/favicon-list.go Normal file
View File

@ -0,0 +1,144 @@
package favicons
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/mrmelon54/png2ico"
"image/png"
"io"
"net/http"
)
// FaviconList contains the ico, png and svg icons for separate favicons
type FaviconList struct {
Ico *FaviconImage // can be generated from png with wrapper
Png *FaviconImage // can be generated from svg with inkscape
Svg *FaviconImage
}
// ProduceIco outputs the bytes of the ico icon or an error
func (l *FaviconList) ProduceIco() ([]byte, error) {
if l.Ico == nil {
return nil, ErrFaviconNotFound
}
return l.Ico.Raw, nil
}
// ProducePng outputs the bytes of the png icon or an error
func (l *FaviconList) ProducePng() ([]byte, error) {
if l.Png == nil {
return nil, ErrFaviconNotFound
}
return l.Png.Raw, nil
}
// ProduceSvg outputs the bytes of the svg icon or an error
func (l *FaviconList) ProduceSvg() ([]byte, error) {
if l.Svg == nil {
return nil, ErrFaviconNotFound
}
return l.Svg.Raw, nil
}
// PreProcess takes an input of the svg2png conversion function and outputs
// an error if the SVG, PNG or ICO fails to download or generate
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
var err error
// SVG
if l.Svg != nil {
// download SVG
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
}
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
}
// PNG
if l.Png != nil {
// download PNG
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
}
} else if l.Svg != nil {
// generate PNG from SVG
l.Png = &FaviconImage{}
l.Png.Raw, err = convert(l.Svg.Raw)
if err != nil {
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
}
}
// ICO
if l.Ico != nil {
// download ICO
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
}
} else if l.Png != nil {
// generate ICO from PNG
l.Ico = &FaviconImage{}
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
if err != nil {
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
}
b := decode.Bounds()
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
if err != nil {
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
}
}
// generate sha256 hashes for svg, png and ico
l.genSha256()
return nil
}
// genSha256 generates sha256 hashes
func (l *FaviconList) genSha256() {
if l.Svg != nil {
l.Svg.Hash = genSha256(l.Svg.Raw)
}
if l.Png != nil {
l.Png.Hash = genSha256(l.Png.Raw)
}
if l.Ico != nil {
l.Ico.Hash = genSha256(l.Ico.Raw)
}
}
// getFaviconViaRequest uses the standard http request library to download
// icons, outputs the raw bytes from the download or an error.
var getFaviconViaRequest = func(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
}
req.Header.Set("X-Violet-Raw-Favicon", "1")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
}
return rawBody, nil
}
// genSha256 generates a sha256 hash as a hex encoded string
func genSha256(in []byte) string {
// create sha256 generator and write to it
h := sha256.New()
_, err := h.Write(in)
if err != nil {
return ""
}
// encode as hex
return hex.EncodeToString(h.Sum(nil))
}

View File

@ -0,0 +1,33 @@
package favicons
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
)
func TestFaviconList_PreProcess(t *testing.T) {
getFaviconViaRequest = func(_ string) ([]byte, error) {
return exampleSvg, nil
}
icons := &FaviconList{Svg: &FaviconImage{Url: "https://example.com/assets/logo.svg"}}
assert.NoError(t, icons.PreProcess(func(in []byte) ([]byte, error) {
return svg2png("inkscape", in)
}))
iconSvg, err := icons.ProduceSvg()
assert.NoError(t, err)
iconPng, err := icons.ProducePng()
assert.NoError(t, err)
iconIco, err := icons.ProduceIco()
assert.NoError(t, err)
assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url)
assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash)
assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash)
assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash)
assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg))
assert.Equal(t, 0, bytes.Compare(examplePng, iconPng))
assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco))
}

View File

@ -1,18 +1,11 @@
package favicons
import (
"bytes"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"github.com/mrmelon54/png2ico"
"golang.org/x/sync/errgroup"
"image/png"
"io"
"log"
"net/http"
"sync"
)
@ -121,150 +114,3 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
return svg2png(f.cmd, in)
}
// FaviconList contains the ico, png and svg icons for separate favicons
type FaviconList struct {
Ico *FaviconImage // can be generated from png with wrapper
Png *FaviconImage // can be generated from svg with inkscape
Svg *FaviconImage
}
// ProduceIco outputs the bytes of the ico icon or an error
func (l *FaviconList) ProduceIco() ([]byte, error) {
if l.Ico == nil {
return nil, ErrFaviconNotFound
}
return l.Ico.Raw, nil
}
// ProducePng outputs the bytes of the png icon or an error
func (l *FaviconList) ProducePng() ([]byte, error) {
if l.Png == nil {
return nil, ErrFaviconNotFound
}
return l.Png.Raw, nil
}
// ProduceSvg outputs the bytes of the svg icon or an error
func (l *FaviconList) ProduceSvg() ([]byte, error) {
if l.Svg == nil {
return nil, ErrFaviconNotFound
}
return l.Svg.Raw, nil
}
// PreProcess takes an input of the svg2png conversion function and outputs
// an error if the SVG, PNG or ICO fails to download or generate
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
var err error
// SVG
if l.Svg != nil {
// download SVG
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
}
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
}
// PNG
if l.Png != nil {
// download PNG
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
}
} else if l.Svg != nil {
// generate PNG from SVG
l.Png = &FaviconImage{}
l.Png.Raw, err = convert(l.Svg.Raw)
if err != nil {
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
}
}
// ICO
if l.Ico != nil {
// download ICO
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
}
} else if l.Png != nil {
// generate ICO from PNG
l.Ico = &FaviconImage{}
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
if err != nil {
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
}
b := decode.Bounds()
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
if err != nil {
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
}
}
// generate sha256 hashes for svg, png and ico
l.genSha256()
return nil
}
// genSha256 generates sha256 hashes
func (l *FaviconList) genSha256() {
if l.Svg != nil {
l.Svg.Hash = genSha256(l.Svg.Raw)
}
if l.Png != nil {
l.Png.Hash = genSha256(l.Png.Raw)
}
if l.Ico != nil {
l.Ico.Hash = genSha256(l.Ico.Raw)
}
}
// getFaviconViaRequest uses the standard http request library to download
// icons, outputs the raw bytes from the download or an error.
func getFaviconViaRequest(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
}
return rawBody, nil
}
// genSha256 generates a sha256 hash as a hex encoded string
func genSha256(in []byte) string {
// create sha256 generator and write to it
h := sha256.New()
_, err := h.Write(in)
if err != nil {
return ""
}
// encode as hex
return hex.EncodeToString(h.Sum(nil))
}
// FaviconImage stores the url, hash and raw bytes of an image
type FaviconImage struct {
Url string
Hash string
Raw []byte
}
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
// the URL is an empty string.
func CreateFaviconImage(url string) *FaviconImage {
if url == "" {
return nil
}
return &FaviconImage{Url: url}
}

51
favicons/favicons_test.go Normal file
View File

@ -0,0 +1,51 @@
package favicons
import (
"bytes"
"database/sql"
_ "embed"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"testing"
)
var (
//go:embed example.svg
exampleSvg []byte
//go:embed example.png
examplePng []byte
//go:embed example.ico
exampleIco []byte
)
func TestFaviconsNew(t *testing.T) {
getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil }
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
favicons := New(db, "inkscape")
_, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "")
assert.NoError(t, err)
favicons.cLock.Lock()
assert.NoError(t, favicons.internalCompile(favicons.faviconMap))
favicons.cLock.Unlock()
icons := favicons.GetIcons("example.com")
iconSvg, err := icons.ProduceSvg()
assert.NoError(t, err)
iconPng, err := icons.ProducePng()
assert.NoError(t, err)
iconIco, err := icons.ProduceIco()
assert.NoError(t, err)
assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url)
assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash)
assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash)
assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash)
assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg))
assert.Equal(t, 0, bytes.Compare(examplePng, iconPng))
assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco))
}

2
go.mod
View File

@ -6,7 +6,6 @@ require (
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1
code.mrmelon54.com/melon/summer-utils v0.0.3
github.com/MrMelon54/trie v0.0.2
github.com/gorilla/mux v1.8.0
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16
github.com/mrmelon54/mjwt v0.0.1
@ -24,6 +23,7 @@ require (
github.com/kr/pretty v0.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/text v0.9.0 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

4
go.sum
View File

@ -9,8 +9,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
@ -43,6 +41,8 @@ golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

73
proxy/hybrid-transport.go Normal file
View File

@ -0,0 +1,73 @@
package proxy
import (
"crypto/tls"
"net"
"net/http"
"sync"
"time"
)
type HybridTransport struct {
baseDialer *net.Dialer
normalTransport http.RoundTripper
insecureTransport http.RoundTripper
socksSync *sync.RWMutex
socksTransport map[string]http.RoundTripper
}
// NewHybridTransport creates a new hybrid transport
func NewHybridTransport() *HybridTransport {
return NewHybridTransportWithCalls(nil, nil)
}
// NewHybridTransportWithCalls creates new hybrid transport with custom normal
// and insecure http.RoundTripper functions.
//
// NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport()
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport {
h := &HybridTransport{
baseDialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
normalTransport: normal,
insecureTransport: insecure,
}
if h.normalTransport == nil {
h.normalTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: h.baseDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 15,
TLSHandshakeTimeout: 10 * time.Second,
IdleConnTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
}
}
if h.insecureTransport == nil {
h.insecureTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: h.baseDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 15,
TLSHandshakeTimeout: 10 * time.Second,
IdleConnTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
}
return h
}
// SecureRoundTrip calls the secure transport
func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, error) {
return h.normalTransport.RoundTrip(req)
}
// InsecureRoundTrip calls the insecure transport
func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) {
return h.insecureTransport.RoundTrip(req)
}

View File

@ -0,0 +1,16 @@
package proxy
import (
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestNewHybridTransport(t *testing.T) {
h := NewHybridTransport()
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
assert.NoError(t, err)
trip, err := h.SecureRoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, trip.StatusCode)
}

View File

@ -1,145 +0,0 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
"golang.org/x/net/proxy"
"log"
"net"
"net/http"
"net/http/httputil"
"sync"
"time"
)
type reverseProxyHostKey int
type ReverseProxyContext interface {
IsIgnoreCert() bool
UpdateHeaders(http.Header)
}
func SetReverseProxyHost(req *http.Request, hf ReverseProxyContext) *http.Request {
ctx := req.Context()
ctx2 := context.WithValue(ctx, reverseProxyHostKey(0), hf)
return req.WithContext(ctx2)
}
func CreateHybridReverseProxy() *httputil.ReverseProxy {
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
Transport: NewHybridTransport(),
ModifyResponse: func(rw *http.Response) error { return nil },
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err)
rw.WriteHeader(http.StatusBadGateway)
_, _ = rw.Write([]byte("502 Bad gateway\n"))
},
}
}
type HybridTransport struct {
baseDialer *net.Dialer
normalTransport http.RoundTripper
insecureTransport http.RoundTripper
socksSync *sync.RWMutex
socksTransport map[string]http.RoundTripper
}
func NewHybridTransport() *HybridTransport {
h := &HybridTransport{
baseDialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
socksSync: &sync.RWMutex{},
socksTransport: make(map[string]http.RoundTripper),
}
h.normalTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: h.baseDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 15,
TLSHandshakeTimeout: 10 * time.Second,
IdleConnTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
}
h.insecureTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: h.baseDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 15,
TLSHandshakeTimeout: 10 * time.Second,
IdleConnTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return h
}
func (h *HybridTransport) RoundTrip(req *http.Request) (*http.Response, error) {
newHost := req.Context().Value(reverseProxyHostKey(0))
hf, ok := newHost.(ReverseProxyContext)
if !ok {
return nil, errors.New("failed to detect reverse proxy configuration")
}
// Do a round trip using existing transports
var trip *http.Response
var err error
if hf.IsIgnoreCert() {
trip, err = h.insecureTransport.RoundTrip(req)
} else {
trip, err = h.normalTransport.RoundTrip(req)
}
if err != nil {
return nil, err
}
// Override headers
hf.UpdateHeaders(trip.Header)
return trip, nil
}
func (h *HybridTransport) getSocksProxy(addr string, insecure bool) (http.RoundTripper, error) {
if insecure {
addr = "%i-" + addr
}
h.socksSync.RLock()
s, ok := h.socksTransport[addr]
h.socksSync.RUnlock()
if ok {
return s, nil
}
dialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("cannot connect to the proxy: %s", err)
}
if f, ok := dialer.(proxy.ContextDialer); ok {
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: f.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 15,
TLSHandshakeTimeout: 10 * time.Second,
IdleConnTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
DisableKeepAlives: true,
}
if insecure {
t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
h.socksSync.Lock()
h.socksTransport[addr] = t
h.socksSync.Unlock()
return t, nil
}
return nil, errors.New("cannot create socks5 dialer")
}

View File

@ -0,0 +1,48 @@
package proxy
import (
"log"
"net/http"
"net/http/httptest"
"net/http/httputil"
"testing"
)
type customTransport struct{}
func (c *customTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
res := httptest.NewRecorder()
res.WriteHeader(http.StatusOK)
res.Write([]byte{0x54, 0x54})
return res.Result(), nil
}
func SetupReverseProxy() *httputil.ReverseProxy {
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
Transport: &customTransport{},
ModifyResponse: func(rw *http.Response) error { return nil },
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err)
rw.WriteHeader(http.StatusBadGateway)
_, _ = rw.Write([]byte("502 Bad gateway\n"))
},
}
}
func BenchmarkHttpUtilReverseProxy(b *testing.B) {
rev := SetupReverseProxy()
req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
for i := 0; i < b.N; i++ {
rec := httptest.NewRecorder()
rev.ServeHTTP(rec, req)
}
}
func BenchmarkCustomTransport(b *testing.B) {
req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
t := &customTransport{}
for i := 0; i < b.N; i++ {
_, _ = t.RoundTrip(req)
}
}

View File

@ -4,11 +4,11 @@ import (
"database/sql"
_ "embed"
"fmt"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"log"
"net/http"
"net/http/httputil"
"strings"
"sync"
)
@ -19,7 +19,7 @@ type Manager struct {
db *sql.DB
s *sync.RWMutex
r *Router
p *httputil.ReverseProxy
p *proxy.HybridTransport
}
var (
@ -35,7 +35,7 @@ var (
// NewManager create a new manager, initialises the routes and redirects tables
// in the database and runs a first time compile.
func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager {
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
m := &Manager{
db: db,
s: &sync.RWMutex{},

View File

@ -3,10 +3,10 @@ package router
import (
"fmt"
"github.com/MrMelon54/trie"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"net/http"
"net/http/httputil"
"strings"
)
@ -14,10 +14,10 @@ type Router struct {
route map[string]*trie.Trie[target.Route]
redirect map[string]*trie.Trie[target.Redirect]
notFound http.Handler
proxy *httputil.ReverseProxy
proxy *proxy.HybridTransport
}
func New(proxy *httputil.ReverseProxy) *Router {
func New(proxy *proxy.HybridTransport) *Router {
return &Router{
route: make(map[string]*trie.Trie[target.Route]),
redirect: make(map[string]*trie.Trie[target.Redirect]),

View File

@ -14,7 +14,7 @@ import (
// endpoints for the reverse proxy.
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Lets Encrypt HTTP verification.
// acme challenges, this is used for Let's Encrypt HTTP verification.
func NewHttpServer(conf *Conf) *http.Server {
r := httprouter.New()
var secureExtend string

View File

@ -3,8 +3,8 @@ package servers
import (
"crypto/tls"
"fmt"
"github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/utils"
"github.com/gorilla/mux"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
"log"
@ -18,7 +18,7 @@ import (
func NewHttpsServer(conf *Conf) *http.Server {
s := &http.Server{
Addr: conf.HttpsListen,
Handler: setupRateLimiter(300).Middleware(conf.Router),
Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)),
DisableGeneralOptionsHandler: false,
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains
@ -51,7 +51,7 @@ func NewHttpsServer(conf *Conf) *http.Server {
// setupRateLimiter is an internal function to create a middleware to manage
// rate limits.
func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
func setupRateLimiter(rateLimit uint64, next http.Handler) http.Handler {
// create memory store
store, err := memorystore.New(&memorystore.Config{
Tokens: rateLimit,
@ -66,5 +66,45 @@ func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
if err != nil {
log.Fatalln(err)
}
return middleware.Handle
return middleware.Handle(next)
}
func setupFaviconMiddleware(fav *favicons.Favicons, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("X-Violet-Raw-Favicon") != "1" {
switch req.URL.Path {
case "/favicon.svg":
icons := fav.GetIcons(req.Host)
raw, err := icons.ProduceSvg()
if err != nil {
utils.RespondVioletError(rw, http.StatusTeapot, "No SVG icon available")
return
}
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(raw)
return
case "/favicon.png":
icons := fav.GetIcons(req.Host)
raw, err := icons.ProducePng()
if err != nil {
utils.RespondVioletError(rw, http.StatusTeapot, "No PNG icon available")
return
}
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(raw)
return
case "/favicon.ico":
icons := fav.GetIcons(req.Host)
raw, err := icons.ProduceIco()
if err != nil {
utils.RespondVioletError(rw, http.StatusTeapot, "No ICO icon available")
return
}
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(raw)
return
}
}
next.ServeHTTP(rw, req)
})
}

View File

@ -1,12 +1,17 @@
package target
import (
"context"
"fmt"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/utils"
"github.com/rs/cors"
"golang.org/x/net/http/httpguts"
"io"
"log"
"net"
"net/http"
"net/textproto"
"net/url"
"path"
"strings"
@ -42,12 +47,9 @@ type Route struct {
ForwardAddr bool // forward remote address
IgnoreCert bool // ignore self-cert
Headers http.Header // extra headers
Proxy http.Handler // reverse proxy handler
Proxy *proxy.HybridTransport // reverse proxy handler
}
// IsIgnoreCert returns true if IgnoreCert is enabled.
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
// UpdateHeaders takes an existing set of headers and overwrites them with the
// extra headers.
func (r Route) UpdateHeaders(header http.Header) {
@ -122,7 +124,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
req2, err := http.NewRequest(req.Method, u.String(), req.Body)
if err != nil {
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
utils.RespondHttpStatus(rw, http.StatusBadGateway)
utils.RespondVioletError(rw, http.StatusBadGateway, "error generating new request")
return
}
@ -153,11 +155,162 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
}
// adds extra request metadata
r.internalReverseProxyMeta(rw, req)
// serve request with reverse proxy
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
var resp *http.Response
if r.IgnoreCert {
resp, err = r.Proxy.InsecureRoundTrip(req2)
} else {
resp, err = r.Proxy.SecureRoundTrip(req2)
}
// copy headers and status code
copyHeader(rw.Header(), resp.Header)
rw.WriteHeader(resp.StatusCode)
// copy body
if resp.Body != nil {
_, err := io.Copy(rw, resp.Body)
if err != nil {
// hijack and close upon error
if h, ok := rw.(http.Hijacker); ok {
hijack, _, err := h.Hijack()
if err != nil {
return
}
_ = hijack.Close()
}
return
}
}
}
// internalReverseProxyMeta is mainly built from code copied from httputil.ReverseProxy,
// due to the highly custom nature of this reverse proxy software we use a copy
// of the code instead of the full httputil implementation to prevent overhead
// from the more generic implementation
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) {
outreq := req.Clone(context.Background())
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handle returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
reqUpType := upgradeType(outreq.Header)
if !asciiIsPrint(reqUpType) {
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType))
return
}
removeHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header["X-Forwarded-For"]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
if !omit {
outreq.Header.Set("X-Forwarded-For", clientIP)
}
}
}
// String outputs a debug string for the route.
func (r Route) String() string {
return fmt.Sprintf("%#v", r)
}
// copyHeader copies all headers from src to dst
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// updateType returns the value of upgrade from http.Header
func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
// IsPrint returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func asciiIsPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
// removeHopByHopHeaders removes the hop-by-hop headers defined in hopHeaders
func removeHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}

View File

@ -1,6 +1,7 @@
package target
import (
"github.com/MrMelon54/violet/proxy"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
@ -9,14 +10,17 @@ import (
type proxyTester struct {
got bool
rw http.ResponseWriter
req *http.Request
}
func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport {
return proxy.NewHybridTransportWithCalls(p, p)
}
func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
p.got = true
p.rw = rw
p.req = req
return &http.Response{StatusCode: http.StatusOK}, nil
}
func TestRoute_FullHost(t *testing.T) {
@ -38,7 +42,7 @@ func TestRoute_ServeHTTP(t *testing.T) {
}
for _, i := range a {
pt := &proxyTester{}
i.Proxy = pt
i.Proxy = pt.makeHybridTransport()
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
i.ServeHTTP(res, req)
@ -62,7 +66,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
req.Header.Set("Origin", "https://test.example.com")
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt}
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
i.ServeHTTP(res, req)
assert.True(t, pt.got)

View File

@ -9,3 +9,8 @@ import (
func RespondHttpStatus(rw http.ResponseWriter, status int) {
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
}
func RespondVioletError(rw http.ResponseWriter, status int, msg string) {
rw.Header().Set("X-Violet-Error", msg)
RespondHttpStatus(rw, status)
}