Improve websocket closing controller

This commit is contained in:
Melon 2023-08-28 23:09:29 +01:00
parent d19050060a
commit a16617b131
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
5 changed files with 57 additions and 25 deletions

View File

@ -62,15 +62,19 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
s.Remove(c)
return
}
done := make(chan struct{}, 1)
d1 := make(chan struct{}, 1)
d2 := make(chan struct{}, 1)
// relay messages each way
go s.wsRelay(done, c, ic)
go s.wsRelay(done, ic, c)
go s.wsRelay(d1, c, ic)
go s.wsRelay(d2, ic, c)
// wait for done signal and close both connections
go func() {
<-done
select {
case <-d1:
case <-d2:
}
_ = c.Close()
_ = ic.Close()
}()
@ -80,7 +84,7 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
defer func() {
done <- struct{}{}
close(done)
}()
for {
mt, message, err := a.ReadMessage()

View File

@ -90,29 +90,29 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
h := r.route[host]
if h != nil {
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
}
}
}
return false
return getServeData(rw, req, h)
}
func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
h := r.redirect[host]
if h != nil {
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
}
return getServeData(rw, req, h)
}
type serveDataInterface interface {
HasFlag(flag target.Flags) bool
ServeHTTP(rw http.ResponseWriter, req *http.Request)
}
func getServeData[T serveDataInterface](rw http.ResponseWriter, req *http.Request, h *trie.Trie[T]) bool {
if h == nil {
return false
}
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
}
}
return false

View File

@ -1,9 +1,12 @@
package router
import (
"fmt"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
"github.com/MrMelon54/trie"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
@ -300,3 +303,24 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
}
}
}
type fakeRoundTripper struct{}
func (f *fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusNotFound)
return rec.Result(), nil
}
func TestGetServeData_Route(t *testing.T) {
hyb := proxy.NewHybridTransportWithCalls(&fakeRoundTripper{}, &fakeRoundTripper{}, nil)
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello/world/this/is/a/test", nil)
assert.NoError(t, err)
h := trie.BuildFromMap(map[string]target.Route{
"/hello/world": {Flags: target.FlagPre, Proxy: hyb},
})
rec := httptest.NewRecorder()
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
fmt.Printf("%#v\n", pairs)
assert.True(t, getServeData(rec, req, h))
}

View File

@ -23,7 +23,7 @@ type RedirectWithActive struct {
Active bool `json:"active"`
}
func (r Route) HasFlag(flag Flags) bool {
func (r Redirect) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}

View File

@ -48,6 +48,10 @@ type RouteWithActive struct {
Active bool `json:"active"`
}
func (r Route) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}
// UpdateHeaders takes an existing set of headers and overwrites them with the
// extra headers.
func (r Route) UpdateHeaders(header http.Header) {