diff --git a/proxy/websocket/server.go b/proxy/websocket/server.go index 50ff72e..4754940 100644 --- a/proxy/websocket/server.go +++ b/proxy/websocket/server.go @@ -62,15 +62,19 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) { s.Remove(c) return } - done := make(chan struct{}, 1) + d1 := make(chan struct{}, 1) + d2 := make(chan struct{}, 1) // relay messages each way - go s.wsRelay(done, c, ic) - go s.wsRelay(done, ic, c) + go s.wsRelay(d1, c, ic) + go s.wsRelay(d2, ic, c) // wait for done signal and close both connections go func() { - <-done + select { + case <-d1: + case <-d2: + } _ = c.Close() _ = ic.Close() }() @@ -80,7 +84,7 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) { func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) { defer func() { - done <- struct{}{} + close(done) }() for { mt, message, err := a.ReadMessage() diff --git a/router/router.go b/router/router.go index 091e968..16964ea 100644 --- a/router/router.go +++ b/router/router.go @@ -90,29 +90,29 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host string) bool { h := r.route[host] - if h != nil { - pairs := h.GetAllKeyValues([]byte(req.URL.Path)) - for i := len(pairs) - 1; i >= 0; i-- { - if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path { - req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key) - pairs[i].Value.ServeHTTP(rw, req) - return true - } - } - } - return false + return getServeData(rw, req, h) } func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool { h := r.redirect[host] - if h != nil { - pairs := h.GetAllKeyValues([]byte(req.URL.Path)) - for i := len(pairs) - 1; i >= 0; i-- { - if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path { - req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key) - pairs[i].Value.ServeHTTP(rw, req) - return true - } + return getServeData(rw, req, h) +} + +type serveDataInterface interface { + HasFlag(flag target.Flags) bool + ServeHTTP(rw http.ResponseWriter, req *http.Request) +} + +func getServeData[T serveDataInterface](rw http.ResponseWriter, req *http.Request, h *trie.Trie[T]) bool { + if h == nil { + return false + } + pairs := h.GetAllKeyValues([]byte(req.URL.Path)) + for i := len(pairs) - 1; i >= 0; i-- { + if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path { + req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key) + pairs[i].Value.ServeHTTP(rw, req) + return true } } return false diff --git a/router/router_test.go b/router/router_test.go index 25aa670..be6e218 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1,9 +1,12 @@ package router import ( + "fmt" "github.com/1f349/violet/proxy" "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/target" + "github.com/MrMelon54/trie" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "net/url" @@ -300,3 +303,24 @@ func TestRouter_AddWildcardRoute(t *testing.T) { } } } + +type fakeRoundTripper struct{} + +func (f *fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusNotFound) + return rec.Result(), nil +} + +func TestGetServeData_Route(t *testing.T) { + hyb := proxy.NewHybridTransportWithCalls(&fakeRoundTripper{}, &fakeRoundTripper{}, nil) + req, err := http.NewRequest(http.MethodGet, "https://example.com/hello/world/this/is/a/test", nil) + assert.NoError(t, err) + h := trie.BuildFromMap(map[string]target.Route{ + "/hello/world": {Flags: target.FlagPre, Proxy: hyb}, + }) + rec := httptest.NewRecorder() + pairs := h.GetAllKeyValues([]byte(req.URL.Path)) + fmt.Printf("%#v\n", pairs) + assert.True(t, getServeData(rec, req, h)) +} diff --git a/target/redirect.go b/target/redirect.go index 4e6766f..8af1e83 100644 --- a/target/redirect.go +++ b/target/redirect.go @@ -23,7 +23,7 @@ type RedirectWithActive struct { Active bool `json:"active"` } -func (r Route) HasFlag(flag Flags) bool { +func (r Redirect) HasFlag(flag Flags) bool { return r.Flags&flag != 0 } diff --git a/target/route.go b/target/route.go index 506583b..c887cf3 100644 --- a/target/route.go +++ b/target/route.go @@ -48,6 +48,10 @@ type RouteWithActive struct { Active bool `json:"active"` } +func (r Route) HasFlag(flag Flags) bool { + return r.Flags&flag != 0 +} + // UpdateHeaders takes an existing set of headers and overwrites them with the // extra headers. func (r Route) UpdateHeaders(header http.Header) {