Add filterWebsocketHeaders function

This commit is contained in:
Melon 2024-08-06 00:00:54 +01:00
parent 1f4f4414d5
commit 8aa82303ce
Signed by: melon
GPG Key ID: 6C9D970C50D26A25

View File

@ -4,6 +4,7 @@ import (
"github.com/1f349/violet/logger" "github.com/1f349/violet/logger"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"net/http" "net/http"
"slices"
"sync" "sync"
"time" "time"
) )
@ -59,7 +60,7 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
Logger.Info("Dialing", "url", req.URL) Logger.Info("Dialing", "url", req.URL)
// dial for internal connection // dial for internal connection
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), nil) ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), filterWebsocketHeaders(req.Header))
if err != nil { if err != nil {
Logger.Info("Failed to dial", "url", req.URL, "err", err) Logger.Info("Failed to dial", "url", req.URL, "err", err)
s.Remove(c) s.Remove(c)
@ -84,6 +85,16 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
} }
} }
// filterWebsocketHeaders allows specific headers to forward to the underlying websocket connection
func filterWebsocketHeaders(headers http.Header) (out http.Header) {
for k, v := range headers {
if k == "Origin" {
out[k] = slices.Clone(v)
}
}
return
}
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) { func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
defer func() { defer func() {
close(done) close(done)