violet/proxy/websocket/server.go

135 lines
2.8 KiB
Go
Raw Normal View History

2023-08-17 14:38:00 +01:00
package websocket
import (
2024-05-13 19:33:33 +01:00
"github.com/1f349/violet/logger"
2023-08-17 14:38:00 +01:00
"github.com/gorilla/websocket"
"net/http"
2024-08-06 00:00:54 +01:00
"slices"
2023-08-17 14:38:00 +01:00
"sync"
"time"
)
2024-05-13 19:33:33 +01:00
// Logger is the logger used by this package. It is derived from logger.Logger
// with the prefix "Violet Websocket".
var Logger = logger.Logger.WithPrefix("Violet Websocket")
2023-08-17 14:38:00 +01:00
// upgrader holds the settings used to turn incoming HTTP requests into
// websocket connections: a five second handshake timeout and 1024 byte
// read/write buffers.
var upgrader = websocket.Upgrader{
	HandshakeTimeout: 5 * time.Second,
	ReadBufferSize:   1024,
	WriteBufferSize:  1024,
	// Every origin is accepted here; deciding which origins are allowed is
	// left to the internal service behind the proxy.
	CheckOrigin: func(*http.Request) bool {
		return true
	},
}
// Server keeps track of upgraded websocket connections so that Shutdown can
// close all of them.
type Server struct {
	// connLock guards connStop and conns.
	connLock *sync.RWMutex
	// connStop is set by Shutdown; once true, Upgrade no longer registers
	// new connections.
	connStop bool
	// conns holds the tracked client connections, keyed by the string form
	// of their remote address.
	conns map[string]*websocket.Conn
}
// NewServer returns a Server with a fresh lock and an empty connection table,
// ready to accept upgrades.
func NewServer() *Server {
	srv := new(Server)
	srv.connLock = &sync.RWMutex{}
	srv.conns = map[string]*websocket.Conn{}
	return srv
}
func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
req.URL.Scheme = "ws"
2024-05-13 19:33:33 +01:00
Logger.Info("Upgrading request", "url", req.URL, "origin", req.Header.Get("Origin"))
2023-08-17 14:38:00 +01:00
c, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
return
}
2023-09-10 21:49:57 +01:00
defer c.Close()
2023-08-17 14:38:00 +01:00
s.connLock.Lock()
// no more connections allowed
if s.connStop {
2023-08-17 14:57:41 +01:00
s.connLock.Unlock()
2023-08-17 14:38:00 +01:00
return
}
// save connection for shutdown
s.conns[c.RemoteAddr().String()] = c
2023-08-17 14:57:41 +01:00
s.connLock.Unlock()
2023-08-17 14:38:00 +01:00
2024-05-13 19:33:33 +01:00
Logger.Info("Dialing", "url", req.URL)
2023-08-17 14:38:00 +01:00
// dial for internal connection
2024-08-06 00:00:54 +01:00
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), filterWebsocketHeaders(req.Header))
2023-08-17 14:38:00 +01:00
if err != nil {
2024-05-13 19:33:33 +01:00
Logger.Info("Failed to dial", "url", req.URL, "err", err)
2023-08-17 14:38:00 +01:00
s.Remove(c)
return
}
2023-09-10 21:49:57 +01:00
defer ic.Close()
2023-08-28 23:09:29 +01:00
d1 := make(chan struct{}, 1)
d2 := make(chan struct{}, 1)
2023-08-17 14:38:00 +01:00
// relay messages each way
2023-08-28 23:09:29 +01:00
go s.wsRelay(d1, c, ic)
go s.wsRelay(d2, ic, c)
2023-08-17 14:38:00 +01:00
// wait for done signal and close both connections
2024-05-13 19:33:33 +01:00
Logger.Info("Completed websocket hijacking")
2023-09-10 21:49:57 +01:00
// waiting until d1 or d2 close then automatically defer close both connections
select {
case <-d1:
case <-d2:
}
2023-08-17 14:38:00 +01:00
}
2024-08-06 00:00:54 +01:00
// filterWebsocketHeaders builds the header set forwarded to the underlying
// websocket connection. Only the "Origin" header is passed through; its values
// are copied so the result does not alias the input. The returned header is
// never nil.
func filterWebsocketHeaders(headers http.Header) (out http.Header) {
	const allowed = "Origin"

	out = http.Header{}
	if values, present := headers[allowed]; present {
		out[allowed] = slices.Clone(values)
	}
	return out
}
2023-08-17 14:38:00 +01:00
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
defer func() {
2023-08-28 23:09:29 +01:00
close(done)
2023-08-17 14:38:00 +01:00
}()
for {
mt, message, err := a.ReadMessage()
if err != nil {
2024-05-13 19:33:33 +01:00
Logger.Info("Read message", "err", err)
2023-08-17 14:38:00 +01:00
return
}
if b.WriteMessage(mt, message) != nil {
return
}
}
}
// Remove drops c from the connection table, keyed by its remote address, and
// then closes it. Any error from closing is ignored.
func (s *Server) Remove(c *websocket.Conn) {
	s.connLock.Lock()
	addr := c.RemoteAddr().String()
	delete(s.conns, addr)
	s.connLock.Unlock()

	// Best effort: there is nothing useful to do if the close fails.
	_ = c.Close()
}
// Shutdown marks the server as stopped, so Upgrade registers no further
// connections, then closes every tracked connection and empties the
// connection table. Errors from closing are ignored.
func (s *Server) Shutdown() {
	s.connLock.Lock()
	defer s.connLock.Unlock()

	// Flag the shutdown first so no new connection is saved afterwards.
	s.connStop = true

	// Close each tracked connection and drop it from the table as we go;
	// emptying the table is not strictly required but keeps it tidy.
	for addr, conn := range s.conns {
		_ = conn.Close()
		delete(s.conns, addr)
	}
}