diff --git a/cmd/violet/serve.go b/cmd/violet/serve.go index d0c08e9..8dcc122 100644 --- a/cmd/violet/serve.go +++ b/cmd/violet/serve.go @@ -10,6 +10,7 @@ import ( errorPages "github.com/1f349/violet/error-pages" "github.com/1f349/violet/favicons" "github.com/1f349/violet/proxy" + "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/router" "github.com/1f349/violet/servers" "github.com/1f349/violet/servers/api" @@ -109,10 +110,11 @@ func normalLoad(startUp startUpConfig, wd string) { certDir := os.DirFS(filepath.Join(wd, "certs")) keyDir := os.DirFS(filepath.Join(wd, "keys")) + ws := websocket.NewServer() allowedDomains := domains.New(db) // load allowed domains acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager - hybridTransport := proxy.NewHybridTransport() // load reverse proxy + hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager @@ -167,5 +169,6 @@ func normalLoad(startUp startUpConfig, wd string) { if srvHttps != nil { srvHttps.Close() } + ws.Shutdown() }) } diff --git a/cmd/violet/setup.go b/cmd/violet/setup.go index 617e02e..02a8a41 100644 --- a/cmd/violet/setup.go +++ b/cmd/violet/setup.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/1f349/violet/domains" "github.com/1f349/violet/proxy" + "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/router" "github.com/1f349/violet/target" "github.com/AlecAivazis/survey/v2" @@ -180,7 +181,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) // add with the route manager, no need to compile as this will run when opened // with the serve subcommand - routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{})) + routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}, &websocket.Server{})) err = routeManager.InsertRoute(target.Route{ Src: path.Join(apiUrl.Host, apiUrl.Path), Dst: answers.ApiListen, diff --git a/go.mod b/go.mod index 4a8f578..a8f47d7 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/MrMelon54/rescheduler v0.0.1 github.com/MrMelon54/trie v0.0.2 github.com/google/subcommands v1.2.0 + github.com/gorilla/websocket v1.5.0 github.com/julienschmidt/httprouter v1.3.0 github.com/mattn/go-sqlite3 v1.14.16 github.com/rs/cors v1.9.0 diff --git a/go.sum b/go.sum index 2c98fe4..9a0c320 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOW github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= diff --git a/proxy/hybrid-transport.go b/proxy/hybrid-transport.go index 22235e0..f9c3871 100644 --- a/proxy/hybrid-transport.go +++ b/proxy/hybrid-transport.go @@ -2,6 +2,7 @@ package proxy import ( "crypto/tls" + "github.com/1f349/violet/proxy/websocket" "net" "net/http" "sync" @@ -14,18 +15,19 @@ type HybridTransport struct { insecureTransport http.RoundTripper socksSync *sync.RWMutex socksTransport map[string]http.RoundTripper + ws *websocket.Server } // NewHybridTransport creates a new hybrid transport -func NewHybridTransport() *HybridTransport { - return NewHybridTransportWithCalls(nil, nil) +func NewHybridTransport(ws *websocket.Server) *HybridTransport { + return NewHybridTransportWithCalls(nil, nil, ws) } // NewHybridTransportWithCalls creates new hybrid transport with custom normal // and insecure http.RoundTripper functions. // // NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport() -func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport { +func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websocket.Server) *HybridTransport { h := &HybridTransport{ baseDialer: &net.Dialer{ Timeout: 30 * time.Second, @@ -33,6 +35,7 @@ func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTran }, normalTransport: normal, insecureTransport: insecure, + ws: ws, } if h.normalTransport == nil { h.normalTransport = &http.Transport{ @@ -71,3 +74,8 @@ func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, er func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) { return h.insecureTransport.RoundTrip(req) } + +// ConnectWebsocket calls the websocket upgrader and thus hijacks the connection +func (h *HybridTransport) ConnectWebsocket(rw http.ResponseWriter, req *http.Request) { + h.ws.Upgrade(rw, req) +} diff --git a/proxy/hybrid-transport_test.go b/proxy/hybrid-transport_test.go index 8c3b34f..ecb78ab 100644 --- a/proxy/hybrid-transport_test.go +++ b/proxy/hybrid-transport_test.go @@ -7,7 +7,7 @@ import ( ) func TestNewHybridTransport(t *testing.T) { - h := NewHybridTransport() + h := NewHybridTransport(nil) req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) assert.NoError(t, err) trip, err := h.SecureRoundTrip(req) diff --git a/proxy/websocket/server.go b/proxy/websocket/server.go new file mode 100644 index 0000000..dc80d97 --- /dev/null +++ b/proxy/websocket/server.go @@ -0,0 +1,108 @@ +package websocket + +import ( + "github.com/gorilla/websocket" + "log" + "net/http" + "sync" + "time" +) + +var upgrader = websocket.Upgrader{ + HandshakeTimeout: time.Second * 5, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // allow requests from any origin + // the internal service can decide what origins to allow + return true + }, +} + +type Server struct { + connLock *sync.RWMutex + connStop bool + conns map[string]*websocket.Conn +} + +func NewServer() *Server { + return &Server{ + connLock: new(sync.RWMutex), + conns: make(map[string]*websocket.Conn), + } +} + +func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(rw, req, nil) + if err != nil { + return + } + s.connLock.Lock() + defer s.connLock.Unlock() + + // no more connections allowed + if s.connStop { + _ = c.Close() + return + } + + // save connection for shutdown + s.conns[c.RemoteAddr().String()] = c + + log.Printf("[Websocket] Dialing: '%s'\n", req.URL.String()) + + // dial for internal connection + ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), req.Header) + if err != nil { + s.Remove(c) + return + } + done := make(chan struct{}, 1) + + // relay messages each way + s.wsRelay(done, c, ic) + s.wsRelay(done, ic, c) + + // wait for done signal and close both connections + go func() { + <-done + _ = c.Close() + _ = ic.Close() + }() +} + +func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) { + defer func() { + done <- struct{}{} + }() + for { + mt, message, err := a.ReadMessage() + if err != nil { + return + } + if b.WriteMessage(mt, message) != nil { + return + } + } +} + +func (s *Server) Remove(c *websocket.Conn) { + s.connLock.Lock() + delete(s.conns, c.RemoteAddr().String()) + s.connLock.Unlock() + _ = c.Close() +} + +func (s *Server) Shutdown() { + s.connLock.Lock() + defer s.connLock.Unlock() + + // flag shutdown and close all open connections + s.connStop = true + for _, i := range s.conns { + _ = i.Close() + } + + // clear connections, not required but do it anyway + s.conns = make(map[string]*websocket.Conn) +} diff --git a/router/manager_test.go b/router/manager_test.go index 940ba5c..882a34c 100644 --- a/router/manager_test.go +++ b/router/manager_test.go @@ -3,6 +3,7 @@ package router import ( "database/sql" "github.com/1f349/violet/proxy" + "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/target" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" @@ -25,7 +26,7 @@ func TestNewManager(t *testing.T) { assert.NoError(t, err) ft := &fakeTransport{} - ht := proxy.NewHybridTransportWithCalls(ft, ft) + ht := proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{}) m := NewManager(db, ht) assert.NoError(t, m.internalCompile(m.r)) diff --git a/router/router_test.go b/router/router_test.go index 4fc0738..25aa670 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -2,6 +2,7 @@ package router import ( "github.com/1f349/violet/proxy" + "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/target" "net/http" "net/http/httptest" @@ -180,7 +181,7 @@ func TestRouter_AddRoute(t *testing.T) { transInsecure := &fakeTransport{} for _, i := range routeTests { - r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure)) + r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{})) dst := i.dst dst.Dst = path.Join("127.0.0.1:8080", dst.Dst) dst.Src = path.Join("example.com", i.path) @@ -266,7 +267,7 @@ func TestRouter_AddWildcardRoute(t *testing.T) { transInsecure := &fakeTransport{} for _, i := range routeTests { - r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure)) + r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{})) dst := i.dst dst.Dst = path.Join("127.0.0.1:8080", dst.Dst) dst.Src = path.Join("*.example.com", i.path) diff --git a/servers/https_test.go b/servers/https_test.go index b0ce4a6..4412c49 100644 --- a/servers/https_test.go +++ b/servers/https_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "github.com/1f349/violet/certs" "github.com/1f349/violet/proxy" + "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/router" "github.com/1f349/violet/servers/conf" "github.com/1f349/violet/utils/fake" @@ -33,7 +34,7 @@ func TestNewHttpsServer_RateLimit(t *testing.T) { Domains: &fake.Domains{}, Certs: certs.New(nil, nil, true), Signer: fake.SnakeOilProv, - Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)), + Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{})), } srv := NewHttpsServer(httpsConf) diff --git a/target/flags.go b/target/flags.go index eccad3c..026863e 100644 --- a/target/flags.go +++ b/target/flags.go @@ -10,6 +10,7 @@ const ( FlagForwardHost FlagForwardAddr FlagIgnoreCert + FlagWebsocket ) var ( diff --git a/target/route.go b/target/route.go index 842e0ab..3926eca 100644 --- a/target/route.go +++ b/target/route.go @@ -1,10 +1,10 @@ package target import ( - "context" "fmt" "github.com/1f349/violet/proxy" "github.com/1f349/violet/utils" + websocket2 "github.com/gorilla/websocket" "github.com/rs/cors" "golang.org/x/net/http/httpguts" "io" @@ -138,12 +138,18 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { if r.HasFlag(FlagForwardHost) { req2.Host = req.Host } - if r.HasFlag(FlagForwardAddr) { - req2.Header.Add("X-Forwarded-For", req.RemoteAddr) - } // adds extra request metadata - r.internalReverseProxyMeta(rw, req) + if r.internalReverseProxyMeta(rw, req, req2) { + return + } + + // switch to websocket handler + // internally the http hijack method is called + if r.HasFlag(FlagWebsocket) && websocket2.IsWebSocketUpgrade(req2) { + r.Proxy.ConnectWebsocket(rw, req2) + return + } // serve request with reverse proxy var resp *http.Response @@ -183,21 +189,20 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { // due to the highly custom nature of this reverse proxy software we use a copy // of the code instead of the full httputil implementation to prevent overhead // from the more generic implementation -func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) { - outreq := req.Clone(context.Background()) +func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req, req2 *http.Request) bool { if req.ContentLength == 0 { - outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + req2.Body = nil // Issue 16036: nil Body for http.Transport retries } - if outreq.Header == nil { - outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + if req2.Header == nil { + req2.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate } - reqUpType := upgradeType(outreq.Header) + reqUpType := upgradeType(req2.Header) if !asciiIsPrint(reqUpType) { utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType)) - return + return true } - removeHopByHopHeaders(outreq.Header) + removeHopByHopHeaders(req2.Header) // Issue 21096: tell backend applications that care about trailer support // that we support trailers. (We do, but we don't go out of our way to @@ -205,29 +210,33 @@ func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Reques // mentioning.) Note that we look at req.Header, not outreq.Header, since // the latter has passed through removeHopByHopHeaders. if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { - outreq.Header.Set("Te", "trailers") + req2.Header.Set("Te", "trailers") } // After stripping all the hop-by-hop connection headers above, add back any // necessary for protocol upgrades, such as for websockets. if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) + req2.Header.Set("Connection", "Upgrade") + req2.Header.Set("Upgrade", reqUpType) } - if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - // If we aren't the first proxy retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - prior, ok := outreq.Header["X-Forwarded-For"] - omit := ok && prior == nil // Issue 38079: nil now means don't populate the header - if len(prior) > 0 { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - if !omit { - outreq.Header.Set("X-Forwarded-For", clientIP) + if r.HasFlag(FlagForwardAddr) { + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := req2.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + if !omit { + req2.Header.Set("X-Forwarded-For", clientIP) + } } } + + return false } // String outputs a debug string for the route. diff --git a/target/route_test.go b/target/route_test.go index 51cf48e..1e2b7b7 100644 --- a/target/route_test.go +++ b/target/route_test.go @@ -3,8 +3,10 @@ package target import ( "bytes" "github.com/1f349/violet/proxy" + "github.com/1f349/violet/proxy/websocket" "github.com/stretchr/testify/assert" "io" + "net" "net/http" "net/http/httptest" "testing" @@ -16,7 +18,7 @@ type proxyTester struct { } func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport { - return proxy.NewHybridTransportWithCalls(p, p) + return proxy.NewHybridTransportWithCalls(p, p, &websocket.Server{}) } func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) { @@ -52,7 +54,9 @@ func TestRoute_ServeHTTP(t *testing.T) { assert.True(t, pt.got) assert.Equal(t, i.target, pt.req.URL.String()) if i.HasFlag(FlagForwardAddr) { - assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For")) + host, _, err := net.SplitHostPort(req.RemoteAddr) + assert.NoError(t, err) + assert.Equal(t, host, pt.req.Header.Get("X-Forwarded-For")) } if i.HasFlag(FlagForwardHost) { assert.Equal(t, req.Host, pt.req.Host) diff --git a/violet.openapi.yaml b/violet.openapi.yaml new file mode 100644 index 0000000..0cd67b3 --- /dev/null +++ b/violet.openapi.yaml @@ -0,0 +1,77 @@ +openapi: 3.0.3 +info: + title: Violet + description: Violet + version: 1.0.0 + contact: + name: Webmaster + email: webmaster@1f349.net +servers: + - url: 'https://api.1f349.net/v1/violet' +paths: + /compile: + post: + summary: Compile quick access data + tags: + - compile + responses: + '202': + description: Compile trigger sent + /domain/{domain}: + put: + summary: Add an allowed domain + tags: + - domain + parameters: + - name: domain + in: path + required: true + description: The domain to add + schema: + type: string + responses: + '202': + description: Domain added and compiled list reloaded + delete: + summary: Remove an allowed domain + tags: + - domain + parameters: + - name: domain + in: path + required: true + description: The domain to remove + schema: + type: string + responses: + '202': + description: Domain removed and compiled list reloaded + /acme-challenge/{domain}/{key}/{value}: + put: + summary: Add ACME challenge value + tags: + - acme-challenge + parameters: + - name: domain + in: path + required: true + description: The domain to add the challenge on + schema: + type: string + responses: + '202': + description: ACME challenge added + delete: + summary: Add ACME challenge value + tags: + - acme-challenge + parameters: + - name: domain + in: path + required: true + description: The domain to add the challenge on + schema: + type: string + responses: + '202': + description: ACME challenge added