Add websocket relay

This commit is contained in:
Melon 2023-08-17 14:38:00 +01:00
parent ce12384c15
commit cf098eb0b9
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
14 changed files with 256 additions and 39 deletions

View File

@ -10,6 +10,7 @@ import (
errorPages "github.com/1f349/violet/error-pages"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/servers"
"github.com/1f349/violet/servers/api"
@ -109,10 +110,11 @@ func normalLoad(startUp startUpConfig, wd string) {
certDir := os.DirFS(filepath.Join(wd, "certs"))
keyDir := os.DirFS(filepath.Join(wd, "keys"))
ws := websocket.NewServer()
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
@ -167,5 +169,6 @@ func normalLoad(startUp startUpConfig, wd string) {
if srvHttps != nil {
srvHttps.Close()
}
ws.Shutdown()
})
}

View File

@ -8,6 +8,7 @@ import (
"fmt"
"github.com/1f349/violet/domains"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/target"
"github.com/AlecAivazis/survey/v2"
@ -180,7 +181,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
// add with the route manager, no need to compile as this will run when opened
// with the serve subcommand
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}))
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}, &websocket.Server{}))
err = routeManager.InsertRoute(target.Route{
Src: path.Join(apiUrl.Host, apiUrl.Path),
Dst: answers.ApiListen,

1
go.mod
View File

@ -11,6 +11,7 @@ require (
github.com/MrMelon54/rescheduler v0.0.1
github.com/MrMelon54/trie v0.0.2
github.com/google/subcommands v1.2.0
github.com/gorilla/websocket v1.5.0
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16
github.com/rs/cors v1.9.0

2
go.sum
View File

@ -23,6 +23,8 @@ github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOW
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=

View File

@ -2,6 +2,7 @@ package proxy
import (
"crypto/tls"
"github.com/1f349/violet/proxy/websocket"
"net"
"net/http"
"sync"
@ -14,18 +15,19 @@ type HybridTransport struct {
insecureTransport http.RoundTripper
socksSync *sync.RWMutex
socksTransport map[string]http.RoundTripper
ws *websocket.Server
}
// NewHybridTransport creates a new hybrid transport
func NewHybridTransport() *HybridTransport {
return NewHybridTransportWithCalls(nil, nil)
func NewHybridTransport(ws *websocket.Server) *HybridTransport {
return NewHybridTransportWithCalls(nil, nil, ws)
}
// NewHybridTransportWithCalls creates new hybrid transport with custom normal
// and insecure http.RoundTripper functions.
//
// NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport()
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport {
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websocket.Server) *HybridTransport {
h := &HybridTransport{
baseDialer: &net.Dialer{
Timeout: 30 * time.Second,
@ -33,6 +35,7 @@ func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTran
},
normalTransport: normal,
insecureTransport: insecure,
ws: ws,
}
if h.normalTransport == nil {
h.normalTransport = &http.Transport{
@ -71,3 +74,8 @@ func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, er
func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) {
return h.insecureTransport.RoundTrip(req)
}
// ConnectWebsocket calls the websocket upgrader and thus hijacks the connection
func (h *HybridTransport) ConnectWebsocket(rw http.ResponseWriter, req *http.Request) {
h.ws.Upgrade(rw, req)
}

View File

@ -7,7 +7,7 @@ import (
)
func TestNewHybridTransport(t *testing.T) {
h := NewHybridTransport()
h := NewHybridTransport(nil)
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
assert.NoError(t, err)
trip, err := h.SecureRoundTrip(req)

108
proxy/websocket/server.go Normal file
View File

@ -0,0 +1,108 @@
package websocket
import (
"github.com/gorilla/websocket"
"log"
"net/http"
"sync"
"time"
)
var upgrader = websocket.Upgrader{
HandshakeTimeout: time.Second * 5,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// allow requests from any origin
// the internal service can decide what origins to allow
return true
},
}
type Server struct {
connLock *sync.RWMutex
connStop bool
conns map[string]*websocket.Conn
}
func NewServer() *Server {
return &Server{
connLock: new(sync.RWMutex),
conns: make(map[string]*websocket.Conn),
}
}
func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
c, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
return
}
s.connLock.Lock()
defer s.connLock.Unlock()
// no more connections allowed
if s.connStop {
_ = c.Close()
return
}
// save connection for shutdown
s.conns[c.RemoteAddr().String()] = c
log.Printf("[Websocket] Dialing: '%s'\n", req.URL.String())
// dial for internal connection
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), req.Header)
if err != nil {
s.Remove(c)
return
}
done := make(chan struct{}, 1)
// relay messages each way
s.wsRelay(done, c, ic)
s.wsRelay(done, ic, c)
// wait for done signal and close both connections
go func() {
<-done
_ = c.Close()
_ = ic.Close()
}()
}
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
defer func() {
done <- struct{}{}
}()
for {
mt, message, err := a.ReadMessage()
if err != nil {
return
}
if b.WriteMessage(mt, message) != nil {
return
}
}
}
func (s *Server) Remove(c *websocket.Conn) {
s.connLock.Lock()
delete(s.conns, c.RemoteAddr().String())
s.connLock.Unlock()
_ = c.Close()
}
func (s *Server) Shutdown() {
s.connLock.Lock()
defer s.connLock.Unlock()
// flag shutdown and close all open connections
s.connStop = true
for _, i := range s.conns {
_ = i.Close()
}
// clear connections, not required but do it anyway
s.conns = make(map[string]*websocket.Conn)
}

View File

@ -3,6 +3,7 @@ package router
import (
"database/sql"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
@ -25,7 +26,7 @@ func TestNewManager(t *testing.T) {
assert.NoError(t, err)
ft := &fakeTransport{}
ht := proxy.NewHybridTransportWithCalls(ft, ft)
ht := proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{})
m := NewManager(db, ht)
assert.NoError(t, m.internalCompile(m.r))

View File

@ -2,6 +2,7 @@ package router
import (
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
"net/http"
"net/http/httptest"
@ -180,7 +181,7 @@ func TestRouter_AddRoute(t *testing.T) {
transInsecure := &fakeTransport{}
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{}))
dst := i.dst
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("example.com", i.path)
@ -266,7 +267,7 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
transInsecure := &fakeTransport{}
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{}))
dst := i.dst
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("*.example.com", i.path)

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"github.com/1f349/violet/certs"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils/fake"
@ -33,7 +34,7 @@ func TestNewHttpsServer_RateLimit(t *testing.T) {
Domains: &fake.Domains{},
Certs: certs.New(nil, nil, true),
Signer: fake.SnakeOilProv,
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{})),
}
srv := NewHttpsServer(httpsConf)

View File

@ -10,6 +10,7 @@ const (
FlagForwardHost
FlagForwardAddr
FlagIgnoreCert
FlagWebsocket
)
var (

View File

@ -1,10 +1,10 @@
package target
import (
"context"
"fmt"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/utils"
websocket2 "github.com/gorilla/websocket"
"github.com/rs/cors"
"golang.org/x/net/http/httpguts"
"io"
@ -138,12 +138,18 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
if r.HasFlag(FlagForwardHost) {
req2.Host = req.Host
}
if r.HasFlag(FlagForwardAddr) {
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
}
// adds extra request metadata
r.internalReverseProxyMeta(rw, req)
if r.internalReverseProxyMeta(rw, req, req2) {
return
}
// switch to websocket handler
// internally the http hijack method is called
if r.HasFlag(FlagWebsocket) && websocket2.IsWebSocketUpgrade(req2) {
r.Proxy.ConnectWebsocket(rw, req2)
return
}
// serve request with reverse proxy
var resp *http.Response
@ -183,21 +189,20 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// due to the highly custom nature of this reverse proxy software we use a copy
// of the code instead of the full httputil implementation to prevent overhead
// from the more generic implementation
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) {
outreq := req.Clone(context.Background())
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req, req2 *http.Request) bool {
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
req2.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
if req2.Header == nil {
req2.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
reqUpType := upgradeType(outreq.Header)
reqUpType := upgradeType(req2.Header)
if !asciiIsPrint(reqUpType) {
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType))
return
return true
}
removeHopByHopHeaders(outreq.Header)
removeHopByHopHeaders(req2.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
@ -205,31 +210,35 @@ func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Reques
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
req2.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
req2.Header.Set("Connection", "Upgrade")
req2.Header.Set("Upgrade", reqUpType)
}
if r.HasFlag(FlagForwardAddr) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header["X-Forwarded-For"]
prior, ok := req2.Header["X-Forwarded-For"]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
if !omit {
outreq.Header.Set("X-Forwarded-For", clientIP)
req2.Header.Set("X-Forwarded-For", clientIP)
}
}
}
return false
}
// String outputs a debug string for the route.
func (r Route) String() string {
return fmt.Sprintf("%#v", r)

View File

@ -3,8 +3,10 @@ package target
import (
"bytes"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/stretchr/testify/assert"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
@ -16,7 +18,7 @@ type proxyTester struct {
}
func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport {
return proxy.NewHybridTransportWithCalls(p, p)
return proxy.NewHybridTransportWithCalls(p, p, &websocket.Server{})
}
func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
@ -52,7 +54,9 @@ func TestRoute_ServeHTTP(t *testing.T) {
assert.True(t, pt.got)
assert.Equal(t, i.target, pt.req.URL.String())
if i.HasFlag(FlagForwardAddr) {
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
host, _, err := net.SplitHostPort(req.RemoteAddr)
assert.NoError(t, err)
assert.Equal(t, host, pt.req.Header.Get("X-Forwarded-For"))
}
if i.HasFlag(FlagForwardHost) {
assert.Equal(t, req.Host, pt.req.Host)

77
violet.openapi.yaml Normal file
View File

@ -0,0 +1,77 @@
openapi: 3.0.3
info:
title: Violet
description: Violet
version: 1.0.0
contact:
name: Webmaster
email: webmaster@1f349.net
servers:
- url: 'https://api.1f349.net/v1/violet'
paths:
/compile:
post:
summary: Compile quick access data
tags:
- compile
responses:
'202':
description: Compile trigger sent
/domain/{domain}:
put:
summary: Add an allowed domain
tags:
- domain
parameters:
- name: domain
in: path
required: true
description: The domain to add
schema:
type: string
responses:
'202':
description: Domain added and compiled list reloaded
delete:
summary: Remove an allowed domain
tags:
- domain
parameters:
- name: domain
in: path
required: true
description: The domain to remove
schema:
type: string
responses:
'202':
description: Domain removed and compiled list reloaded
/acme-challenge/{domain}/{key}/{value}:
put:
summary: Add ACME challenge value
tags:
- acme-challenge
parameters:
- name: domain
in: path
required: true
description: The domain to add the challenge on
schema:
type: string
responses:
'202':
description: ACME challenge added
delete:
summary: Add ACME challenge value
tags:
- acme-challenge
parameters:
- name: domain
in: path
required: true
description: The domain to add the challenge on
schema:
type: string
responses:
'202':
description: ACME challenge added