violet/proxy/websocket/server.go
2024-08-06 00:18:36 +01:00

135 lines
2.8 KiB
Go

package websocket
import (
"github.com/1f349/violet/logger"
"github.com/gorilla/websocket"
"net/http"
"slices"
"sync"
"time"
)
var (
	// Logger is the package logger, derived from logger.Logger with the
	// "Violet Websocket" prefix.
	Logger = logger.Logger.WithPrefix("Violet Websocket")

	// upgrader holds the settings used to turn incoming HTTP requests into
	// websocket connections: a 5 second handshake timeout and 1 KiB read and
	// write buffers.
	upgrader = websocket.Upgrader{
		HandshakeTimeout: 5 * time.Second,
		ReadBufferSize:   1024,
		WriteBufferSize:  1024,
		// Every origin is accepted here; the internal service receives the
		// Origin header and can decide what origins to allow.
		CheckOrigin: func(_ *http.Request) bool { return true },
	}
)
// Server upgrades HTTP requests to websocket connections and keeps track of
// the open client connections so they can be closed on shutdown.
type Server struct {
	// connLock guards conns and connStop.
	connLock *sync.RWMutex
	// conns holds the tracked client connections keyed by remote address.
	conns map[string]*websocket.Conn
	// connStop is set by Shutdown; once true no new connection is tracked.
	connStop bool
}
// NewServer returns a Server with an initialised lock and an empty
// connection table.
func NewServer() *Server {
	srv := new(Server)
	srv.connLock = &sync.RWMutex{}
	srv.conns = map[string]*websocket.Conn{}
	return srv
}
// Upgrade upgrades the incoming HTTP request to a websocket connection, dials
// the request URL (with its scheme forced to "ws") as the internal connection
// and relays messages in both directions until either relay stops.
//
// The client connection is tracked in s.conns for as long as this call is
// running so that Shutdown can close it; it is removed again when the call
// returns. If Shutdown has already been called, the freshly upgraded
// connection is closed immediately. Upgrade blocks for the lifetime of the
// relay.
func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
	req.URL.Scheme = "ws"
	Logger.Info("Upgrading request", "url", req.URL, "origin", req.Header.Get("Origin"))

	c, err := upgrader.Upgrade(rw, req, nil)
	if err != nil {
		// the upgrader replies to the client itself; just record the failure
		Logger.Info("Failed to upgrade", "url", req.URL, "err", err)
		return
	}
	defer c.Close()

	s.connLock.Lock()
	// no more connections allowed
	if s.connStop {
		s.connLock.Unlock()
		return
	}
	// save connection for shutdown
	s.conns[c.RemoteAddr().String()] = c
	s.connLock.Unlock()
	// Forget the connection on every exit path. Without this, sessions that
	// end normally would stay in s.conns until Shutdown. Remove also closes
	// c; the extra deferred Close above is then a harmless no-op.
	defer s.Remove(c)

	Logger.Info("Dialing", "url", req.URL)

	// dial for internal connection, forwarding only the allowed headers
	ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), filterWebsocketHeaders(req.Header))
	if err != nil {
		Logger.Info("Failed to dial", "url", req.URL, "err", err)
		return
	}
	defer ic.Close()

	// each relay closes its channel when it stops, so no buffer is needed
	d1 := make(chan struct{})
	d2 := make(chan struct{})

	// relay messages each way
	go s.wsRelay(d1, c, ic)
	go s.wsRelay(d2, ic, c)

	Logger.Info("Completed websocket hijacking")

	// Wait until either direction stops. The deferred closes then tear down
	// both connections, which makes the other relay's read fail and exit.
	select {
	case <-d1:
	case <-d2:
	}
}
// filterWebsocketHeaders builds the header set forwarded to the underlying
// websocket connection. Only the "Origin" entry is carried over, as a copy, so
// the result never shares value slices with the input. The returned header is
// always non-nil.
func filterWebsocketHeaders(headers http.Header) (out http.Header) {
	const allowed = "Origin"

	out = http.Header{}
	if values, present := headers[allowed]; present {
		out[allowed] = slices.Clone(values)
	}
	return out
}
// wsRelay copies messages read from a to b, preserving the message type,
// until a read or a write fails. Read failures are logged; write failures end
// the relay silently. done is closed when the relay returns.
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
	defer close(done)

	for {
		kind, payload, readErr := a.ReadMessage()
		if readErr != nil {
			Logger.Info("Read message", "err", readErr)
			return
		}
		if writeErr := b.WriteMessage(kind, payload); writeErr != nil {
			return
		}
	}
}
// Remove drops the entry stored under c's remote address from the tracked
// connections and then closes c. The error from closing is discarded.
func (s *Server) Remove(c *websocket.Conn) {
	key := c.RemoteAddr().String()

	// hold the lock only for the map update, not for the close
	func() {
		s.connLock.Lock()
		defer s.connLock.Unlock()
		delete(s.conns, key)
	}()

	_ = c.Close()
}
// Shutdown marks the server as stopped, so Upgrade tracks no further
// connections, closes every tracked connection and resets the connection
// table. Errors from closing are discarded.
func (s *Server) Shutdown() {
	s.connLock.Lock()
	defer s.connLock.Unlock()

	// refuse new connections from here on
	s.connStop = true

	for addr := range s.conns {
		_ = s.conns[addr].Close()
	}

	// start over with an empty table; not required but done anyway
	s.conns = map[string]*websocket.Conn{}
}