mirror of
https://github.com/1f349/violet.git
synced 2024-11-08 10:06:53 +00:00
Yes rewrote some stuff
This commit is contained in:
parent
9899d67d50
commit
1f487eb80c
@ -1,12 +0,0 @@
|
||||
module benchmarks
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/MrMelon54/violet v0.0.0-20230419182034-77d570ac1e6d
|
||||
github.com/gorilla/mux v1.8.0
|
||||
)
|
||||
|
||||
require github.com/MrMelon54/trie v0.0.2 // indirect
|
||||
|
||||
replace github.com/MrMelon54/violet => ../
|
@ -1,8 +0,0 @@
|
||||
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
|
||||
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
@ -1,42 +0,0 @@
|
||||
package benchmarks
|
||||
|
||||
import (
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
gorillaRouter "github.com/gorilla/mux"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func benchRequest(b *testing.B, router http.Handler, r *http.Request) {
|
||||
w := httptest.NewRecorder()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
router.ServeHTTP(w, r)
|
||||
}
|
||||
if w.Header().Get("Location") != "https://example.com" {
|
||||
b.Fatal("Location: ", w.Header().Get("Location"), " != https://example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkVioletRouter(b *testing.B) {
|
||||
r := router.New(nil)
|
||||
r.AddRedirect("*.example.com", "", target.Redirect{
|
||||
Pre: true,
|
||||
Host: "example.com",
|
||||
Code: http.StatusPermanentRedirect,
|
||||
})
|
||||
benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil))
|
||||
}
|
||||
|
||||
func BenchmarkGorillaMux(b *testing.B) {
|
||||
r := gorillaRouter.NewRouter()
|
||||
r.Host("{subdomain}.example.com").Handler(target.Redirect{
|
||||
Pre: true,
|
||||
Host: "example.com",
|
||||
Code: http.StatusPermanentRedirect,
|
||||
})
|
||||
benchRequest(b, r, httptest.NewRequest(http.MethodGet, "https://www.example.com/", nil))
|
||||
}
|
@ -62,7 +62,7 @@ func main() {
|
||||
|
||||
allowedDomains := domains.New(db) // load allowed domains
|
||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
|
||||
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
|
||||
reverseProxy := proxy.NewHybridTransport() // load reverse proxy
|
||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
||||
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager
|
||||
|
@ -48,12 +48,15 @@ func (d *Domains) IsValid(host string) bool {
|
||||
defer d.s.RUnlock()
|
||||
|
||||
// check root domains `www.example.com`, `example.com`, `com`
|
||||
// TODO: could be faster using indexes and cropping the string?
|
||||
n := strings.Split(domain, ".")
|
||||
for i := 0; i < len(n); i++ {
|
||||
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
|
||||
for len(domain) > 0 {
|
||||
if _, ok := d.m[domain]; ok {
|
||||
return true
|
||||
}
|
||||
n := strings.IndexByte(domain, '.')
|
||||
if n == -1 {
|
||||
break
|
||||
}
|
||||
domain = domain[n+1:]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
45
domains/domains_test.go
Normal file
45
domains/domains_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
package domains
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainsNew(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||
assert.NoError(t, err)
|
||||
|
||||
domains := New(db)
|
||||
_, err = db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
|
||||
assert.NoError(t, err)
|
||||
domains.Compile()
|
||||
|
||||
if _, ok := domains.m["example.com"]; ok {
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
if _, ok := domains.m["www.example.com"]; !ok {
|
||||
assert.False(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomains_IsValid(t *testing.T) {
|
||||
// open sqlite database
|
||||
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||
assert.NoError(t, err)
|
||||
|
||||
domains := New(db)
|
||||
_, err = domains.db.Exec("insert into domains (domain, active) values (?, ?)", "example.com", 1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
domains.s.Lock()
|
||||
assert.NoError(t, domains.internalCompile(domains.m))
|
||||
domains.s.Unlock()
|
||||
|
||||
assert.True(t, domains.IsValid("example.com"))
|
||||
assert.True(t, domains.IsValid("www.example.com"))
|
||||
assert.False(t, domains.IsValid("notexample.com"))
|
||||
assert.False(t, domains.IsValid("www.notexample.com"))
|
||||
}
|
@ -29,6 +29,7 @@ func New(dir fs.FS) *ErrorPages {
|
||||
generic: func(rw http.ResponseWriter, code int) {
|
||||
// if status text is empty then the code is unknown
|
||||
a := http.StatusText(code)
|
||||
fmt.Printf("%d - %s\n", code, a)
|
||||
if a != "" {
|
||||
// output in "xxx Error Text" format
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
|
||||
@ -64,11 +65,13 @@ func (e *ErrorPages) Compile() {
|
||||
errorPageMap := make(map[int]func(rw http.ResponseWriter))
|
||||
|
||||
// compile map and check errors
|
||||
if e.dir != nil {
|
||||
err := e.internalCompile(errorPageMap)
|
||||
if err != nil {
|
||||
log.Printf("[Certs] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// lock while replacing the map
|
||||
e.s.Lock()
|
||||
|
31
error-pages/error-pages_test.go
Normal file
31
error-pages/error-pages_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package error_pages
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorPages_ServeError(t *testing.T) {
|
||||
errorPages := New(nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
errorPages.ServeError(rec, http.StatusTeapot)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||
assert.Equal(t, "418 I'm a teapot", res.Status)
|
||||
a, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
errorPages.ServeError(rec, 469)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, 469, res.StatusCode)
|
||||
assert.Equal(t, "469 ", res.Status)
|
||||
a, err = io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "469 Unknown Error Code\n\n", string(a))
|
||||
}
|
BIN
favicons/example.ico
Normal file
BIN
favicons/example.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
BIN
favicons/example.png
Normal file
BIN
favicons/example.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
4
favicons/example.svg
Normal file
4
favicons/example.svg
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg height="100" width="100" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="50" cy="50" r="40" stroke="black" stroke-width="3" fill="red"/>
|
||||
</svg>
|
After Width: | Height: | Size: 214 B |
17
favicons/favicon-image.go
Normal file
17
favicons/favicon-image.go
Normal file
@ -0,0 +1,17 @@
|
||||
package favicons
|
||||
|
||||
// FaviconImage stores the url, hash and raw bytes of an image
|
||||
type FaviconImage struct {
|
||||
Url string
|
||||
Hash string
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
|
||||
// the URL is an empty string.
|
||||
func CreateFaviconImage(url string) *FaviconImage {
|
||||
if url == "" {
|
||||
return nil
|
||||
}
|
||||
return &FaviconImage{Url: url}
|
||||
}
|
144
favicons/favicon-list.go
Normal file
144
favicons/favicon-list.go
Normal file
@ -0,0 +1,144 @@
|
||||
package favicons
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/mrmelon54/png2ico"
|
||||
"image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// FaviconList contains the ico, png and svg icons for separate favicons
|
||||
type FaviconList struct {
|
||||
Ico *FaviconImage // can be generated from png with wrapper
|
||||
Png *FaviconImage // can be generated from svg with inkscape
|
||||
Svg *FaviconImage
|
||||
}
|
||||
|
||||
// ProduceIco outputs the bytes of the ico icon or an error
|
||||
func (l *FaviconList) ProduceIco() ([]byte, error) {
|
||||
if l.Ico == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
}
|
||||
return l.Ico.Raw, nil
|
||||
}
|
||||
|
||||
// ProducePng outputs the bytes of the png icon or an error
|
||||
func (l *FaviconList) ProducePng() ([]byte, error) {
|
||||
if l.Png == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
}
|
||||
return l.Png.Raw, nil
|
||||
}
|
||||
|
||||
// ProduceSvg outputs the bytes of the svg icon or an error
|
||||
func (l *FaviconList) ProduceSvg() ([]byte, error) {
|
||||
if l.Svg == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
}
|
||||
return l.Svg.Raw, nil
|
||||
}
|
||||
|
||||
// PreProcess takes an input of the svg2png conversion function and outputs
|
||||
// an error if the SVG, PNG or ICO fails to download or generate
|
||||
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
|
||||
var err error
|
||||
|
||||
// SVG
|
||||
if l.Svg != nil {
|
||||
// download SVG
|
||||
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
|
||||
}
|
||||
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
|
||||
}
|
||||
|
||||
// PNG
|
||||
if l.Png != nil {
|
||||
// download PNG
|
||||
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
|
||||
}
|
||||
} else if l.Svg != nil {
|
||||
// generate PNG from SVG
|
||||
l.Png = &FaviconImage{}
|
||||
l.Png.Raw, err = convert(l.Svg.Raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ICO
|
||||
if l.Ico != nil {
|
||||
// download ICO
|
||||
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
|
||||
}
|
||||
} else if l.Png != nil {
|
||||
// generate ICO from PNG
|
||||
l.Ico = &FaviconImage{}
|
||||
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
|
||||
}
|
||||
b := decode.Bounds()
|
||||
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// generate sha256 hashes for svg, png and ico
|
||||
l.genSha256()
|
||||
return nil
|
||||
}
|
||||
|
||||
// genSha256 generates sha256 hashes
|
||||
func (l *FaviconList) genSha256() {
|
||||
if l.Svg != nil {
|
||||
l.Svg.Hash = genSha256(l.Svg.Raw)
|
||||
}
|
||||
if l.Png != nil {
|
||||
l.Png.Hash = genSha256(l.Png.Raw)
|
||||
}
|
||||
if l.Ico != nil {
|
||||
l.Ico.Hash = genSha256(l.Ico.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
// getFaviconViaRequest uses the standard http request library to download
|
||||
// icons, outputs the raw bytes from the download or an error.
|
||||
var getFaviconViaRequest = func(url string) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
|
||||
}
|
||||
req.Header.Set("X-Violet-Raw-Favicon", "1")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
|
||||
}
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
|
||||
}
|
||||
return rawBody, nil
|
||||
}
|
||||
|
||||
// genSha256 generates a sha256 hash as a hex encoded string
|
||||
func genSha256(in []byte) string {
|
||||
// create sha256 generator and write to it
|
||||
h := sha256.New()
|
||||
_, err := h.Write(in)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// encode as hex
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
33
favicons/favicon-list_test.go
Normal file
33
favicons/favicon-list_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package favicons
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFaviconList_PreProcess(t *testing.T) {
|
||||
getFaviconViaRequest = func(_ string) ([]byte, error) {
|
||||
return exampleSvg, nil
|
||||
}
|
||||
icons := &FaviconList{Svg: &FaviconImage{Url: "https://example.com/assets/logo.svg"}}
|
||||
assert.NoError(t, icons.PreProcess(func(in []byte) ([]byte, error) {
|
||||
return svg2png("inkscape", in)
|
||||
}))
|
||||
iconSvg, err := icons.ProduceSvg()
|
||||
assert.NoError(t, err)
|
||||
iconPng, err := icons.ProducePng()
|
||||
assert.NoError(t, err)
|
||||
iconIco, err := icons.ProduceIco()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url)
|
||||
|
||||
assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash)
|
||||
assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash)
|
||||
assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash)
|
||||
|
||||
assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg))
|
||||
assert.Equal(t, 0, bytes.Compare(examplePng, iconPng))
|
||||
assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco))
|
||||
}
|
@ -1,18 +1,11 @@
|
||||
package favicons
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/mrmelon54/png2ico"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"image/png"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -121,150 +114,3 @@ func (f *Favicons) internalCompile(faviconMap map[string]*FaviconList) error {
|
||||
func (f *Favicons) convertSvgToPng(in []byte) ([]byte, error) {
|
||||
return svg2png(f.cmd, in)
|
||||
}
|
||||
|
||||
// FaviconList contains the ico, png and svg icons for separate favicons
|
||||
type FaviconList struct {
|
||||
Ico *FaviconImage // can be generated from png with wrapper
|
||||
Png *FaviconImage // can be generated from svg with inkscape
|
||||
Svg *FaviconImage
|
||||
}
|
||||
|
||||
// ProduceIco outputs the bytes of the ico icon or an error
|
||||
func (l *FaviconList) ProduceIco() ([]byte, error) {
|
||||
if l.Ico == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
}
|
||||
return l.Ico.Raw, nil
|
||||
}
|
||||
|
||||
// ProducePng outputs the bytes of the png icon or an error
|
||||
func (l *FaviconList) ProducePng() ([]byte, error) {
|
||||
if l.Png == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
}
|
||||
return l.Png.Raw, nil
|
||||
}
|
||||
|
||||
// ProduceSvg outputs the bytes of the svg icon or an error
|
||||
func (l *FaviconList) ProduceSvg() ([]byte, error) {
|
||||
if l.Svg == nil {
|
||||
return nil, ErrFaviconNotFound
|
||||
}
|
||||
return l.Svg.Raw, nil
|
||||
}
|
||||
|
||||
// PreProcess takes an input of the svg2png conversion function and outputs
|
||||
// an error if the SVG, PNG or ICO fails to download or generate
|
||||
func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error {
|
||||
var err error
|
||||
|
||||
// SVG
|
||||
if l.Svg != nil {
|
||||
// download SVG
|
||||
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
|
||||
}
|
||||
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
|
||||
}
|
||||
|
||||
// PNG
|
||||
if l.Png != nil {
|
||||
// download PNG
|
||||
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
|
||||
}
|
||||
} else if l.Svg != nil {
|
||||
// generate PNG from SVG
|
||||
l.Png = &FaviconImage{}
|
||||
l.Png.Raw, err = convert(l.Svg.Raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ICO
|
||||
if l.Ico != nil {
|
||||
// download ICO
|
||||
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
|
||||
}
|
||||
} else if l.Png != nil {
|
||||
// generate ICO from PNG
|
||||
l.Ico = &FaviconImage{}
|
||||
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
|
||||
}
|
||||
b := decode.Bounds()
|
||||
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
|
||||
if err != nil {
|
||||
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// generate sha256 hashes for svg, png and ico
|
||||
l.genSha256()
|
||||
return nil
|
||||
}
|
||||
|
||||
// genSha256 generates sha256 hashes
|
||||
func (l *FaviconList) genSha256() {
|
||||
if l.Svg != nil {
|
||||
l.Svg.Hash = genSha256(l.Svg.Raw)
|
||||
}
|
||||
if l.Png != nil {
|
||||
l.Png.Hash = genSha256(l.Png.Raw)
|
||||
}
|
||||
if l.Ico != nil {
|
||||
l.Ico.Hash = genSha256(l.Ico.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
// getFaviconViaRequest uses the standard http request library to download
|
||||
// icons, outputs the raw bytes from the download or an error.
|
||||
func getFaviconViaRequest(url string) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
|
||||
}
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
|
||||
}
|
||||
return rawBody, nil
|
||||
}
|
||||
|
||||
// genSha256 generates a sha256 hash as a hex encoded string
|
||||
func genSha256(in []byte) string {
|
||||
// create sha256 generator and write to it
|
||||
h := sha256.New()
|
||||
_, err := h.Write(in)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// encode as hex
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// FaviconImage stores the url, hash and raw bytes of an image
|
||||
type FaviconImage struct {
|
||||
Url string
|
||||
Hash string
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
|
||||
// the URL is an empty string.
|
||||
func CreateFaviconImage(url string) *FaviconImage {
|
||||
if url == "" {
|
||||
return nil
|
||||
}
|
||||
return &FaviconImage{Url: url}
|
||||
}
|
||||
|
51
favicons/favicons_test.go
Normal file
51
favicons/favicons_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
package favicons
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed example.svg
|
||||
exampleSvg []byte
|
||||
//go:embed example.png
|
||||
examplePng []byte
|
||||
//go:embed example.ico
|
||||
exampleIco []byte
|
||||
)
|
||||
|
||||
func TestFaviconsNew(t *testing.T) {
|
||||
getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil }
|
||||
|
||||
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||
assert.NoError(t, err)
|
||||
|
||||
favicons := New(db, "inkscape")
|
||||
_, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "")
|
||||
assert.NoError(t, err)
|
||||
favicons.cLock.Lock()
|
||||
assert.NoError(t, favicons.internalCompile(favicons.faviconMap))
|
||||
favicons.cLock.Unlock()
|
||||
|
||||
icons := favicons.GetIcons("example.com")
|
||||
iconSvg, err := icons.ProduceSvg()
|
||||
assert.NoError(t, err)
|
||||
iconPng, err := icons.ProducePng()
|
||||
assert.NoError(t, err)
|
||||
iconIco, err := icons.ProduceIco()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "https://example.com/assets/logo.svg", icons.Svg.Url)
|
||||
|
||||
assert.Equal(t, "74cdc17d0502a690941799c327d9ca1ed042e76c784def43a42937f2eed270b4", icons.Svg.Hash)
|
||||
assert.Equal(t, "84841341dafbb1e54c62d160dfc5e48c3f8db4b22265a4dbe2e0318debf9b670", icons.Png.Hash)
|
||||
assert.Equal(t, "33fc667fdb0e32305f2ee27e7dd7feb781cc776638d0971db7e18cc6335a15c7", icons.Ico.Hash)
|
||||
|
||||
assert.Equal(t, 0, bytes.Compare(exampleSvg, iconSvg))
|
||||
assert.Equal(t, 0, bytes.Compare(examplePng, iconPng))
|
||||
assert.Equal(t, 0, bytes.Compare(exampleIco, iconIco))
|
||||
}
|
2
go.mod
2
go.mod
@ -6,7 +6,6 @@ require (
|
||||
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1
|
||||
code.mrmelon54.com/melon/summer-utils v0.0.3
|
||||
github.com/MrMelon54/trie v0.0.2
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/mattn/go-sqlite3 v1.14.16
|
||||
github.com/mrmelon54/mjwt v0.0.1
|
||||
@ -24,6 +23,7 @@ require (
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
4
go.sum
4
go.sum
@ -9,8 +9,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
@ -43,6 +41,8 @@ golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
73
proxy/hybrid-transport.go
Normal file
73
proxy/hybrid-transport.go
Normal file
@ -0,0 +1,73 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HybridTransport struct {
|
||||
baseDialer *net.Dialer
|
||||
normalTransport http.RoundTripper
|
||||
insecureTransport http.RoundTripper
|
||||
socksSync *sync.RWMutex
|
||||
socksTransport map[string]http.RoundTripper
|
||||
}
|
||||
|
||||
// NewHybridTransport creates a new hybrid transport
|
||||
func NewHybridTransport() *HybridTransport {
|
||||
return NewHybridTransportWithCalls(nil, nil)
|
||||
}
|
||||
|
||||
// NewHybridTransportWithCalls creates new hybrid transport with custom normal
|
||||
// and insecure http.RoundTripper functions.
|
||||
//
|
||||
// NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport()
|
||||
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport {
|
||||
h := &HybridTransport{
|
||||
baseDialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
},
|
||||
normalTransport: normal,
|
||||
insecureTransport: insecure,
|
||||
}
|
||||
if h.normalTransport == nil {
|
||||
h.normalTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: h.baseDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 15,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
if h.insecureTransport == nil {
|
||||
h.insecureTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: h.baseDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 15,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// SecureRoundTrip calls the secure transport
|
||||
func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return h.normalTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// InsecureRoundTrip calls the insecure transport
|
||||
func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return h.insecureTransport.RoundTrip(req)
|
||||
}
|
16
proxy/hybrid-transport_test.go
Normal file
16
proxy/hybrid-transport_test.go
Normal file
@ -0,0 +1,16 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewHybridTransport(t *testing.T) {
|
||||
h := NewHybridTransport()
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
trip, err := h.SecureRoundTrip(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, trip.StatusCode)
|
||||
}
|
@ -1,145 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/proxy"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type reverseProxyHostKey int
|
||||
|
||||
type ReverseProxyContext interface {
|
||||
IsIgnoreCert() bool
|
||||
UpdateHeaders(http.Header)
|
||||
}
|
||||
|
||||
func SetReverseProxyHost(req *http.Request, hf ReverseProxyContext) *http.Request {
|
||||
ctx := req.Context()
|
||||
ctx2 := context.WithValue(ctx, reverseProxyHostKey(0), hf)
|
||||
return req.WithContext(ctx2)
|
||||
}
|
||||
|
||||
func CreateHybridReverseProxy() *httputil.ReverseProxy {
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {},
|
||||
Transport: NewHybridTransport(),
|
||||
ModifyResponse: func(rw *http.Response) error { return nil },
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte("502 Bad gateway\n"))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type HybridTransport struct {
|
||||
baseDialer *net.Dialer
|
||||
normalTransport http.RoundTripper
|
||||
insecureTransport http.RoundTripper
|
||||
socksSync *sync.RWMutex
|
||||
socksTransport map[string]http.RoundTripper
|
||||
}
|
||||
|
||||
func NewHybridTransport() *HybridTransport {
|
||||
h := &HybridTransport{
|
||||
baseDialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
},
|
||||
socksSync: &sync.RWMutex{},
|
||||
socksTransport: make(map[string]http.RoundTripper),
|
||||
}
|
||||
h.normalTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: h.baseDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 15,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
h.insecureTransport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: h.baseDialer.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 15,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *HybridTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newHost := req.Context().Value(reverseProxyHostKey(0))
|
||||
hf, ok := newHost.(ReverseProxyContext)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to detect reverse proxy configuration")
|
||||
}
|
||||
|
||||
// Do a round trip using existing transports
|
||||
var trip *http.Response
|
||||
var err error
|
||||
if hf.IsIgnoreCert() {
|
||||
trip, err = h.insecureTransport.RoundTrip(req)
|
||||
} else {
|
||||
trip, err = h.normalTransport.RoundTrip(req)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Override headers
|
||||
hf.UpdateHeaders(trip.Header)
|
||||
return trip, nil
|
||||
}
|
||||
|
||||
func (h *HybridTransport) getSocksProxy(addr string, insecure bool) (http.RoundTripper, error) {
|
||||
if insecure {
|
||||
addr = "%i-" + addr
|
||||
}
|
||||
h.socksSync.RLock()
|
||||
s, ok := h.socksTransport[addr]
|
||||
h.socksSync.RUnlock()
|
||||
if ok {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
dialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot connect to the proxy: %s", err)
|
||||
}
|
||||
|
||||
if f, ok := dialer.(proxy.ContextDialer); ok {
|
||||
t := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: f.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 15,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
if insecure {
|
||||
t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
h.socksSync.Lock()
|
||||
h.socksTransport[addr] = t
|
||||
h.socksSync.Unlock()
|
||||
return t, nil
|
||||
}
|
||||
return nil, errors.New("cannot create socks5 dialer")
|
||||
}
|
48
proxy/reverse-proxy_test.go
Normal file
48
proxy/reverse-proxy_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type customTransport struct{}
|
||||
|
||||
func (c *customTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
res := httptest.NewRecorder()
|
||||
res.WriteHeader(http.StatusOK)
|
||||
res.Write([]byte{0x54, 0x54})
|
||||
return res.Result(), nil
|
||||
}
|
||||
|
||||
func SetupReverseProxy() *httputil.ReverseProxy {
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {},
|
||||
Transport: &customTransport{},
|
||||
ModifyResponse: func(rw *http.Response) error { return nil },
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[ReverseProxy] Request: %#v\n -- Error: %s\n", req, err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte("502 Bad gateway\n"))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHttpUtilReverseProxy(b *testing.B) {
|
||||
rev := SetupReverseProxy()
|
||||
req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
rev.ServeHTTP(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCustomTransport(b *testing.B) {
|
||||
req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
t := &customTransport{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = t.RoundTrip(req)
|
||||
}
|
||||
}
|
@ -4,11 +4,11 @@ import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@ -19,7 +19,7 @@ type Manager struct {
|
||||
db *sql.DB
|
||||
s *sync.RWMutex
|
||||
r *Router
|
||||
p *httputil.ReverseProxy
|
||||
p *proxy.HybridTransport
|
||||
}
|
||||
|
||||
var (
|
||||
@ -35,7 +35,7 @@ var (
|
||||
|
||||
// NewManager create a new manager, initialises the routes and redirects tables
|
||||
// in the database and runs a first time compile.
|
||||
func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager {
|
||||
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
|
||||
m := &Manager{
|
||||
db: db,
|
||||
s: &sync.RWMutex{},
|
||||
|
@ -3,10 +3,10 @@ package router
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/MrMelon54/trie"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -14,10 +14,10 @@ type Router struct {
|
||||
route map[string]*trie.Trie[target.Route]
|
||||
redirect map[string]*trie.Trie[target.Redirect]
|
||||
notFound http.Handler
|
||||
proxy *httputil.ReverseProxy
|
||||
proxy *proxy.HybridTransport
|
||||
}
|
||||
|
||||
func New(proxy *httputil.ReverseProxy) *Router {
|
||||
func New(proxy *proxy.HybridTransport) *Router {
|
||||
return &Router{
|
||||
route: make(map[string]*trie.Trie[target.Route]),
|
||||
redirect: make(map[string]*trie.Trie[target.Redirect]),
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
// endpoints for the reverse proxy.
|
||||
//
|
||||
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||
// acme challenges, this is used for Lets Encrypt HTTP verification.
|
||||
// acme challenges, this is used for Let's Encrypt HTTP verification.
|
||||
func NewHttpServer(conf *Conf) *http.Server {
|
||||
r := httprouter.New()
|
||||
var secureExtend string
|
||||
|
@ -3,8 +3,8 @@ package servers
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sethvargo/go-limiter/httplimit"
|
||||
"github.com/sethvargo/go-limiter/memorystore"
|
||||
"log"
|
||||
@ -18,7 +18,7 @@ import (
|
||||
func NewHttpsServer(conf *Conf) *http.Server {
|
||||
s := &http.Server{
|
||||
Addr: conf.HttpsListen,
|
||||
Handler: setupRateLimiter(300).Middleware(conf.Router),
|
||||
Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
||||
DisableGeneralOptionsHandler: false,
|
||||
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// error out on invalid domains
|
||||
@ -51,7 +51,7 @@ func NewHttpsServer(conf *Conf) *http.Server {
|
||||
|
||||
// setupRateLimiter is an internal function to create a middleware to manage
|
||||
// rate limits.
|
||||
func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
|
||||
func setupRateLimiter(rateLimit uint64, next http.Handler) http.Handler {
|
||||
// create memory store
|
||||
store, err := memorystore.New(&memorystore.Config{
|
||||
Tokens: rateLimit,
|
||||
@ -66,5 +66,45 @@ func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
return middleware.Handle
|
||||
return middleware.Handle(next)
|
||||
}
|
||||
|
||||
func setupFaviconMiddleware(fav *favicons.Favicons, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("X-Violet-Raw-Favicon") != "1" {
|
||||
switch req.URL.Path {
|
||||
case "/favicon.svg":
|
||||
icons := fav.GetIcons(req.Host)
|
||||
raw, err := icons.ProduceSvg()
|
||||
if err != nil {
|
||||
utils.RespondVioletError(rw, http.StatusTeapot, "No SVG icon available")
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(raw)
|
||||
return
|
||||
case "/favicon.png":
|
||||
icons := fav.GetIcons(req.Host)
|
||||
raw, err := icons.ProducePng()
|
||||
if err != nil {
|
||||
utils.RespondVioletError(rw, http.StatusTeapot, "No PNG icon available")
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(raw)
|
||||
return
|
||||
case "/favicon.ico":
|
||||
icons := fav.GetIcons(req.Host)
|
||||
raw, err := icons.ProduceIco()
|
||||
if err != nil {
|
||||
utils.RespondVioletError(rw, http.StatusTeapot, "No ICO icon available")
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(raw)
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
165
target/route.go
165
target/route.go
@ -1,12 +1,17 @@
|
||||
package target
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/rs/cors"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
@ -42,12 +47,9 @@ type Route struct {
|
||||
ForwardAddr bool // forward remote address
|
||||
IgnoreCert bool // ignore self-cert
|
||||
Headers http.Header // extra headers
|
||||
Proxy http.Handler // reverse proxy handler
|
||||
Proxy *proxy.HybridTransport // reverse proxy handler
|
||||
}
|
||||
|
||||
// IsIgnoreCert returns true if IgnoreCert is enabled.
|
||||
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
|
||||
|
||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||
// extra headers.
|
||||
func (r Route) UpdateHeaders(header http.Header) {
|
||||
@ -122,7 +124,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
req2, err := http.NewRequest(req.Method, u.String(), req.Body)
|
||||
if err != nil {
|
||||
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
|
||||
utils.RespondHttpStatus(rw, http.StatusBadGateway)
|
||||
utils.RespondVioletError(rw, http.StatusBadGateway, "error generating new request")
|
||||
return
|
||||
}
|
||||
|
||||
@ -153,11 +155,162 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||
}
|
||||
|
||||
// adds extra request metadata
|
||||
r.internalReverseProxyMeta(rw, req)
|
||||
|
||||
// serve request with reverse proxy
|
||||
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
|
||||
var resp *http.Response
|
||||
if r.IgnoreCert {
|
||||
resp, err = r.Proxy.InsecureRoundTrip(req2)
|
||||
} else {
|
||||
resp, err = r.Proxy.SecureRoundTrip(req2)
|
||||
}
|
||||
|
||||
// copy headers and status code
|
||||
copyHeader(rw.Header(), resp.Header)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
|
||||
// copy body
|
||||
if resp.Body != nil {
|
||||
_, err := io.Copy(rw, resp.Body)
|
||||
if err != nil {
|
||||
// hijack and close upon error
|
||||
if h, ok := rw.(http.Hijacker); ok {
|
||||
hijack, _, err := h.Hijack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = hijack.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalReverseProxyMeta is mainly built from code copied from httputil.ReverseProxy,
|
||||
// due to the highly custom nature of this reverse proxy software we use a copy
|
||||
// of the code instead of the full httputil implementation to prevent overhead
|
||||
// from the more generic implementation
|
||||
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) {
|
||||
outreq := req.Clone(context.Background())
|
||||
if req.ContentLength == 0 {
|
||||
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
// Reading from the request body after returning from a handler is not
|
||||
// allowed, and the RoundTrip goroutine that reads the Body can outlive
|
||||
// this handler. This can lead to a crash if the handler panics (see
|
||||
// Issue 46866). Although calling Close doesn't guarantee there isn't
|
||||
// any Read in flight after the handle returns, in practice it's safe to
|
||||
// read after closing it.
|
||||
defer outreq.Body.Close()
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
reqUpType := upgradeType(outreq.Header)
|
||||
if !asciiIsPrint(reqUpType) {
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
|
||||
// Issue 21096: tell backend applications that care about trailer support
|
||||
// that we support trailers. (We do, but we don't go out of our way to
|
||||
// advertise that unless the incoming client request thought it was worth
|
||||
// mentioning.) Note that we look at req.Header, not outreq.Header, since
|
||||
// the latter has passed through removeHopByHopHeaders.
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
// After stripping all the hop-by-hop connection headers above, add back any
|
||||
// necessary for protocol upgrades, such as for websockets.
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
prior, ok := outreq.Header["X-Forwarded-For"]
|
||||
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
|
||||
if len(prior) > 0 {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
if !omit {
|
||||
outreq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String outputs a debug string for the route.
|
||||
func (r Route) String() string {
|
||||
return fmt.Sprintf("%#v", r)
|
||||
}
|
||||
|
||||
// copyHeader copies all headers from src to dst
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateType returns the value of upgrade from http.Header
|
||||
func upgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
// IsPrint returns whether s is ASCII and printable according to
|
||||
// https://tools.ietf.org/html/rfc20#section-4.2.
|
||||
func asciiIsPrint(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < ' ' || s[i] > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders removes the hop-by-hop headers defined in hopHeaders
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||
for _, f := range h["Connection"] {
|
||||
for _, sf := range strings.Split(f, ",") {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||
// preserve it for backwards compatibility.
|
||||
for _, f := range hopHeaders {
|
||||
h.Del(f)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package target
|
||||
|
||||
import (
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -9,14 +10,17 @@ import (
|
||||
|
||||
type proxyTester struct {
|
||||
got bool
|
||||
rw http.ResponseWriter
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport {
|
||||
return proxy.NewHybridTransportWithCalls(p, p)
|
||||
}
|
||||
|
||||
func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
p.got = true
|
||||
p.rw = rw
|
||||
p.req = req
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
}
|
||||
|
||||
func TestRoute_FullHost(t *testing.T) {
|
||||
@ -38,7 +42,7 @@ func TestRoute_ServeHTTP(t *testing.T) {
|
||||
}
|
||||
for _, i := range a {
|
||||
pt := &proxyTester{}
|
||||
i.Proxy = pt
|
||||
i.Proxy = pt.makeHybridTransport()
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
|
||||
i.ServeHTTP(res, req)
|
||||
@ -62,7 +66,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
||||
req.Header.Set("Origin", "https://test.example.com")
|
||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt}
|
||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
|
||||
i.ServeHTTP(res, req)
|
||||
|
||||
assert.True(t, pt.got)
|
||||
|
@ -9,3 +9,8 @@ import (
|
||||
func RespondHttpStatus(rw http.ResponseWriter, status int) {
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
|
||||
}
|
||||
|
||||
func RespondVioletError(rw http.ResponseWriter, status int, msg string) {
|
||||
rw.Header().Set("X-Violet-Error", msg)
|
||||
RespondHttpStatus(rw, status)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user