violet/router/router_test.go
2024-04-20 16:17:32 +01:00

327 lines
8.8 KiB
Go

package router
import (
"fmt"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
"github.com/mrmelon54/trie"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"path"
"testing"
)
type routeTestBase struct {
path string
dst target.Route
tests map[string]string
}
type redirectTestBase struct {
path string
dst target.Redirect
tests map[string]string
}
type mss map[string]string
var (
routeTests = []routeTestBase{
{"/", target.Route{}, mss{
"/": "/",
"/hello": "",
}},
{"/", target.Route{Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Route{Flags: target.FlagAbs}, mss{
"/": "/",
"/hello": "",
}},
{"/", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Route{Flags: target.FlagPre}, mss{
"/": "/",
"/hello": "/hello",
}},
{"/", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world/hello",
}},
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "/",
"/hello": "/",
}},
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world",
}},
{"/hello", target.Route{}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Route{Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Route{Flags: target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Route{Flags: target.FlagPre}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/hi",
}},
{"/hello", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world/hi",
}},
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/",
}},
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world",
}},
}
redirectTests = []redirectTestBase{
{"/", target.Redirect{}, mss{
"/": "/",
"/hello": "",
}},
{"/", target.Redirect{Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Redirect{Flags: target.FlagAbs}, mss{
"/": "/",
"/hello": "",
}},
{"/", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Redirect{Flags: target.FlagPre}, mss{
"/": "/",
"/hello": "/hello",
}},
{"/", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world/hello",
}},
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "/",
"/hello": "/",
}},
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world",
}},
{"/hello", target.Redirect{}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Flags: target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Flags: target.FlagPre}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/hi",
}},
{"/hello", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world/hi",
}},
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/",
}},
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world",
}},
}
)
func TestRouter_AddRoute(t *testing.T) {
transSecure := &fakeTransport{}
transInsecure := &fakeTransport{}
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{}))
dst := i.dst
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRoute(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if v == "" {
if transSecure.req != nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["example.com"].String())
t.Fatalf("%s => %s\n", k, v)
}
} else {
if transSecure.req == nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", k, v, k, "")
}
if v != transSecure.req.URL.Path {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", k, v, k, transSecure.req.URL.Path)
}
transSecure.req = nil
}
}
}
}
func TestRouter_AddRedirect(t *testing.T) {
for _, i := range redirectTests {
r := New(nil)
dst := i.dst
dst.Dst = path.Join("example.com", dst.Dst)
dst.Code = http.StatusFound
dst.Src = path.Join("www.example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRedirect(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
if v == "" {
u1 = nil
}
u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k}
assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2))
}
}
}
func assertHttpRedirect(t *testing.T, r *Router, code int, target, method, start string) {
res := httptest.NewRecorder()
req := httptest.NewRequest(method, start, nil)
r.ServeHTTP(res, req)
l := res.Header().Get("Location")
if target == "" {
if code == res.Code || "" != l {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.redirect["www.example.com"].String())
t.Fatalf("%s => %s\n", start, target)
}
} else {
if code != res.Code || target != l {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.redirect["www.example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", start, target, start, l)
}
}
}
func outputUrl(u *url.URL) string {
if u == nil {
return ""
}
return u.String()
}
func TestRouter_AddWildcardRoute(t *testing.T) {
transSecure := &fakeTransport{}
transInsecure := &fakeTransport{}
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{}))
dst := i.dst
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("*.example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRoute(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "test.example.com", Path: k}
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if v == "" {
if transSecure.req != nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["*.example.com"].String())
t.Fatalf("%s => %s\n", k, v)
}
} else {
if transSecure.req == nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["*.example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", k, v, k, "")
}
if v != transSecure.req.URL.Path {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["*.example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", k, v, k, transSecure.req.URL.Path)
}
transSecure.req = nil
}
}
}
}
type fakeRoundTripper struct{}
func (f *fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusNotFound)
return rec.Result(), nil
}
func TestGetServeData_Route(t *testing.T) {
hyb := proxy.NewHybridTransportWithCalls(&fakeRoundTripper{}, &fakeRoundTripper{}, nil)
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello/world/this/is/a/test", nil)
assert.NoError(t, err)
h := trie.BuildFromMap(map[string]target.Route{
"/hello/world": {Flags: target.FlagPre, Proxy: hyb},
})
rec := httptest.NewRecorder()
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
fmt.Printf("%#v\n", pairs)
assert.True(t, getServeData(rec, req, h))
}