mirror of
https://github.com/1f349/violet.git
synced 2024-11-09 22:22:50 +00:00
Improve websocket closing controller
This commit is contained in:
parent
d19050060a
commit
a16617b131
@ -62,15 +62,19 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
|
|||||||
s.Remove(c)
|
s.Remove(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
done := make(chan struct{}, 1)
|
d1 := make(chan struct{}, 1)
|
||||||
|
d2 := make(chan struct{}, 1)
|
||||||
|
|
||||||
// relay messages each way
|
// relay messages each way
|
||||||
go s.wsRelay(done, c, ic)
|
go s.wsRelay(d1, c, ic)
|
||||||
go s.wsRelay(done, ic, c)
|
go s.wsRelay(d2, ic, c)
|
||||||
|
|
||||||
// wait for done signal and close both connections
|
// wait for done signal and close both connections
|
||||||
go func() {
|
go func() {
|
||||||
<-done
|
select {
|
||||||
|
case <-d1:
|
||||||
|
case <-d2:
|
||||||
|
}
|
||||||
_ = c.Close()
|
_ = c.Close()
|
||||||
_ = ic.Close()
|
_ = ic.Close()
|
||||||
}()
|
}()
|
||||||
@ -80,7 +84,7 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
|
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
|
||||||
defer func() {
|
defer func() {
|
||||||
done <- struct{}{}
|
close(done)
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
mt, message, err := a.ReadMessage()
|
mt, message, err := a.ReadMessage()
|
||||||
|
@ -90,29 +90,29 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
|
func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
|
||||||
h := r.route[host]
|
h := r.route[host]
|
||||||
if h != nil {
|
return getServeData(rw, req, h)
|
||||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
|
||||||
for i := len(pairs) - 1; i >= 0; i-- {
|
|
||||||
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
|
||||||
pairs[i].Value.ServeHTTP(rw, req)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
|
func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
|
||||||
h := r.redirect[host]
|
h := r.redirect[host]
|
||||||
if h != nil {
|
return getServeData(rw, req, h)
|
||||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
}
|
||||||
for i := len(pairs) - 1; i >= 0; i-- {
|
|
||||||
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
type serveDataInterface interface {
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
HasFlag(flag target.Flags) bool
|
||||||
pairs[i].Value.ServeHTTP(rw, req)
|
ServeHTTP(rw http.ResponseWriter, req *http.Request)
|
||||||
return true
|
}
|
||||||
}
|
|
||||||
|
func getServeData[T serveDataInterface](rw http.ResponseWriter, req *http.Request, h *trie.Trie[T]) bool {
|
||||||
|
if h == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||||
|
for i := len(pairs) - 1; i >= 0; i-- {
|
||||||
|
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||||
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||||
|
pairs[i].Value.ServeHTTP(rw, req)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/1f349/violet/proxy"
|
"github.com/1f349/violet/proxy"
|
||||||
"github.com/1f349/violet/proxy/websocket"
|
"github.com/1f349/violet/proxy/websocket"
|
||||||
"github.com/1f349/violet/target"
|
"github.com/1f349/violet/target"
|
||||||
|
"github.com/MrMelon54/trie"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -300,3 +303,24 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeRoundTripper struct{}
|
||||||
|
|
||||||
|
func (f *fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rec.WriteHeader(http.StatusNotFound)
|
||||||
|
return rec.Result(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetServeData_Route(t *testing.T) {
|
||||||
|
hyb := proxy.NewHybridTransportWithCalls(&fakeRoundTripper{}, &fakeRoundTripper{}, nil)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello/world/this/is/a/test", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
h := trie.BuildFromMap(map[string]target.Route{
|
||||||
|
"/hello/world": {Flags: target.FlagPre, Proxy: hyb},
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||||
|
fmt.Printf("%#v\n", pairs)
|
||||||
|
assert.True(t, getServeData(rec, req, h))
|
||||||
|
}
|
||||||
|
@ -23,7 +23,7 @@ type RedirectWithActive struct {
|
|||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Route) HasFlag(flag Flags) bool {
|
func (r Redirect) HasFlag(flag Flags) bool {
|
||||||
return r.Flags&flag != 0
|
return r.Flags&flag != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,6 +48,10 @@ type RouteWithActive struct {
|
|||||||
Active bool `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r Route) HasFlag(flag Flags) bool {
|
||||||
|
return r.Flags&flag != 0
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||||
// extra headers.
|
// extra headers.
|
||||||
func (r Route) UpdateHeaders(header http.Header) {
|
func (r Route) UpdateHeaders(header http.Header) {
|
||||||
|
Loading…
Reference in New Issue
Block a user