mirror of
https://github.com/1f349/violet.git
synced 2024-11-09 22:22:50 +00:00
134 lines
2.8 KiB
Go
134 lines
2.8 KiB
Go
package websocket
|
|
|
|
import (
|
|
"github.com/1f349/violet/logger"
|
|
"github.com/gorilla/websocket"
|
|
"net/http"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Logger is the package-level logger used by the websocket server. It is
// derived from logger.Logger with the "Violet Websocket" prefix.
var Logger = logger.Logger.WithPrefix("Violet Websocket")
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
HandshakeTimeout: time.Second * 5,
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
// allow requests from any origin
|
|
// the internal service can decide what origins to allow
|
|
return true
|
|
},
|
|
}
|
|
|
|
type Server struct {
|
|
connLock *sync.RWMutex
|
|
connStop bool
|
|
conns map[string]*websocket.Conn
|
|
}
|
|
|
|
func NewServer() *Server {
|
|
return &Server{
|
|
connLock: new(sync.RWMutex),
|
|
conns: make(map[string]*websocket.Conn),
|
|
}
|
|
}
|
|
|
|
func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
|
|
req.URL.Scheme = "ws"
|
|
Logger.Info("Upgrading request", "url", req.URL, "origin", req.Header.Get("Origin"))
|
|
|
|
c, err := upgrader.Upgrade(rw, req, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
defer c.Close()
|
|
s.connLock.Lock()
|
|
|
|
// no more connections allowed
|
|
if s.connStop {
|
|
s.connLock.Unlock()
|
|
return
|
|
}
|
|
|
|
// save connection for shutdown
|
|
s.conns[c.RemoteAddr().String()] = c
|
|
s.connLock.Unlock()
|
|
|
|
Logger.Info("Dialing", "url", req.URL)
|
|
|
|
// dial for internal connection
|
|
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), filterWebsocketHeaders(req.Header))
|
|
if err != nil {
|
|
Logger.Info("Failed to dial", "url", req.URL, "err", err)
|
|
s.Remove(c)
|
|
return
|
|
}
|
|
defer ic.Close()
|
|
|
|
d1 := make(chan struct{}, 1)
|
|
d2 := make(chan struct{}, 1)
|
|
|
|
// relay messages each way
|
|
go s.wsRelay(d1, c, ic)
|
|
go s.wsRelay(d2, ic, c)
|
|
|
|
// wait for done signal and close both connections
|
|
Logger.Info("Completed websocket hijacking")
|
|
|
|
// waiting until d1 or d2 close then automatically defer close both connections
|
|
select {
|
|
case <-d1:
|
|
case <-d2:
|
|
}
|
|
}
|
|
|
|
// filterWebsocketHeaders returns a new header set containing only the headers
// that are forwarded to the underlying websocket connection. At present that
// is just "Origin".
//
// The returned header is always non-nil and never shares value slices with
// the input, so callers may modify it freely.
func filterWebsocketHeaders(headers http.Header) (out http.Header) {
	// The named result starts out as a nil map and assigning into a nil map
	// panics, so it has to be allocated before the first write.
	out = make(http.Header)
	for k, v := range headers {
		if k == "Origin" {
			out[k] = slices.Clone(v)
		}
	}
	return out
}
|
|
|
|
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
|
|
defer func() {
|
|
close(done)
|
|
}()
|
|
for {
|
|
mt, message, err := a.ReadMessage()
|
|
if err != nil {
|
|
Logger.Info("Read message", "err", err)
|
|
return
|
|
}
|
|
if b.WriteMessage(mt, message) != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) Remove(c *websocket.Conn) {
|
|
s.connLock.Lock()
|
|
delete(s.conns, c.RemoteAddr().String())
|
|
s.connLock.Unlock()
|
|
_ = c.Close()
|
|
}
|
|
|
|
func (s *Server) Shutdown() {
|
|
s.connLock.Lock()
|
|
defer s.connLock.Unlock()
|
|
|
|
// flag shutdown and close all open connections
|
|
s.connStop = true
|
|
for _, i := range s.conns {
|
|
_ = i.Close()
|
|
}
|
|
|
|
// clear connections, not required but do it anyway
|
|
s.conns = make(map[string]*websocket.Conn)
|
|
}
|