From b5ff809345cdab058318622e4d93507716f311f2 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Thu, 17 Aug 2023 15:23:23 +0100 Subject: [PATCH] Some changes after debugging websockets --- cmd/violet/serve.go | 2 +- proxy/websocket/server.go | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cmd/violet/serve.go b/cmd/violet/serve.go index 3c18da9..2dd4a18 100644 --- a/cmd/violet/serve.go +++ b/cmd/violet/serve.go @@ -170,7 +170,7 @@ func normalLoad(startUp startUpConfig, wd string) { _ = srvHttp.Close() } if srvHttps != nil { - _ = srvHttps.Shutdown(context.Background()) + _ = srvHttps.Close() } }) } diff --git a/proxy/websocket/server.go b/proxy/websocket/server.go index 9a752aa..50ff72e 100644 --- a/proxy/websocket/server.go +++ b/proxy/websocket/server.go @@ -33,6 +33,7 @@ func NewServer() *Server { } func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) { + req.URL.Scheme = "ws" log.Printf("[Websocket] Upgrading request to '%s' from '%s'\n", req.URL.String(), req.Header.Get("Origin")) c, err := upgrader.Upgrade(rw, req, nil) @@ -55,16 +56,17 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) { log.Printf("[Websocket] Dialing: '%s'\n", req.URL.String()) // dial for internal connection - ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), req.Header) + ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), nil) if err != nil { + log.Printf("[Websocket] Failed to dial '%s': %s\n", req.URL.String(), err) s.Remove(c) return } done := make(chan struct{}, 1) // relay messages each way - s.wsRelay(done, c, ic) - s.wsRelay(done, ic, c) + go s.wsRelay(done, c, ic) + go s.wsRelay(done, ic, c) // wait for done signal and close both connections go func() { @@ -72,6 +74,8 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) { _ = c.Close() _ = ic.Close() }() + + log.Println("[Websocket] Completed websocket hijacking") } func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) { @@ -81,6 +85,7 @@ func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) { for { mt, message, err := a.ReadMessage() if err != nil { + log.Println("Websocket read message error: ", err) return } if b.WriteMessage(mt, message) != nil {