mirror of
https://github.com/1f349/violet.git
synced 2024-12-21 23:14:04 +00:00
Improve websocket closing controller
This commit is contained in:
parent
d19050060a
commit
a16617b131
@ -62,15 +62,19 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
|
||||
s.Remove(c)
|
||||
return
|
||||
}
|
||||
done := make(chan struct{}, 1)
|
||||
d1 := make(chan struct{}, 1)
|
||||
d2 := make(chan struct{}, 1)
|
||||
|
||||
// relay messages each way
|
||||
go s.wsRelay(done, c, ic)
|
||||
go s.wsRelay(done, ic, c)
|
||||
go s.wsRelay(d1, c, ic)
|
||||
go s.wsRelay(d2, ic, c)
|
||||
|
||||
// wait for done signal and close both connections
|
||||
go func() {
|
||||
<-done
|
||||
select {
|
||||
case <-d1:
|
||||
case <-d2:
|
||||
}
|
||||
_ = c.Close()
|
||||
_ = ic.Close()
|
||||
}()
|
||||
@ -80,7 +84,7 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
|
||||
defer func() {
|
||||
done <- struct{}{}
|
||||
close(done)
|
||||
}()
|
||||
for {
|
||||
mt, message, err := a.ReadMessage()
|
||||
|
@ -90,29 +90,29 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
|
||||
h := r.route[host]
|
||||
if h != nil {
|
||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||
pairs[i].Value.ServeHTTP(rw, req)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
return getServeData(rw, req, h)
|
||||
}
|
||||
|
||||
func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
|
||||
h := r.redirect[host]
|
||||
if h != nil {
|
||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||
pairs[i].Value.ServeHTTP(rw, req)
|
||||
return true
|
||||
}
|
||||
return getServeData(rw, req, h)
|
||||
}
|
||||
|
||||
type serveDataInterface interface {
|
||||
HasFlag(flag target.Flags) bool
|
||||
ServeHTTP(rw http.ResponseWriter, req *http.Request)
|
||||
}
|
||||
|
||||
func getServeData[T serveDataInterface](rw http.ResponseWriter, req *http.Request, h *trie.Trie[T]) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||
pairs[i].Value.ServeHTTP(rw, req)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
@ -1,9 +1,12 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/1f349/violet/proxy"
|
||||
"github.com/1f349/violet/proxy/websocket"
|
||||
"github.com/1f349/violet/target"
|
||||
"github.com/MrMelon54/trie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -300,3 +303,24 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct{}
|
||||
|
||||
func (f *fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
rec := httptest.NewRecorder()
|
||||
rec.WriteHeader(http.StatusNotFound)
|
||||
return rec.Result(), nil
|
||||
}
|
||||
|
||||
func TestGetServeData_Route(t *testing.T) {
|
||||
hyb := proxy.NewHybridTransportWithCalls(&fakeRoundTripper{}, &fakeRoundTripper{}, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello/world/this/is/a/test", nil)
|
||||
assert.NoError(t, err)
|
||||
h := trie.BuildFromMap(map[string]target.Route{
|
||||
"/hello/world": {Flags: target.FlagPre, Proxy: hyb},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||
fmt.Printf("%#v\n", pairs)
|
||||
assert.True(t, getServeData(rec, req, h))
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ type RedirectWithActive struct {
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (r Route) HasFlag(flag Flags) bool {
|
||||
func (r Redirect) HasFlag(flag Flags) bool {
|
||||
return r.Flags&flag != 0
|
||||
}
|
||||
|
||||
|
@ -48,6 +48,10 @@ type RouteWithActive struct {
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (r Route) HasFlag(flag Flags) bool {
|
||||
return r.Flags&flag != 0
|
||||
}
|
||||
|
||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||
// extra headers.
|
||||
func (r Route) UpdateHeaders(header http.Header) {
|
||||
|
Loading…
Reference in New Issue
Block a user