diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9b1dffd --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.sqlite diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml index 2e275d2..aba4cb2 100644 --- a/.idea/dataSources.xml +++ b/.idea/dataSources.xml @@ -1,11 +1,11 @@ - + sqlite.xerial true org.sqlite.JDBC - jdbc:sqlite:identifier.sqlite + jdbc:sqlite:__db.sqlite $ProjectFileDir$ diff --git a/certs/certs.go b/certs/certs.go index b79b3e5..66a0f6d 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -3,34 +3,55 @@ package certs import ( "code.mrmelon54.com/melon/certgen" "crypto/tls" + "crypto/x509/pkix" "fmt" "github.com/MrMelon54/violet/utils" "io/fs" "log" + "math/big" "path/filepath" "sync" + "sync/atomic" + "time" ) // Certs is the certificate loader and management system. type Certs struct { cDir fs.FS kDir fs.FS + ss bool s *sync.RWMutex m map[string]*tls.Certificate + ca *certgen.CertGen + sn atomic.Int64 } // New creates a new cert list -func New(certDir fs.FS, keyDir fs.FS) *Certs { - a := &Certs{ +func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs { + c := &Certs{ cDir: certDir, kDir: keyDir, + ss: selfCert, s: &sync.RWMutex{}, m: make(map[string]*tls.Certificate), } + if c.ss { + ca, err := certgen.MakeCaTls(pkix.Name{ + Country: []string{"GB"}, + Organization: []string{"Violet"}, + OrganizationalUnit: []string{"Development"}, + SerialNumber: "0", + CommonName: fmt.Sprintf("%d.violet.test", time.Now().Unix()), + }, big.NewInt(0)) + if err != nil { + log.Fatalln("Failed to generate CA cert for self-signed mode:", err) + } + c.ca = ca + } // run compile to get the initial data - a.Compile() - return a + c.Compile() + return c } func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { @@ -43,6 +64,24 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { return cert } + // if self-signed certificate is enabled then generate a certificate + if c.ss { + sn := c.sn.Add(1) + serverTls, err := certgen.MakeServerTls(c.ca, pkix.Name{ + Country: []string{"GB"}, + Organization: []string{domain}, + OrganizationalUnit: []string{domain}, + SerialNumber: fmt.Sprintf("%d", sn), + CommonName: domain, + }, big.NewInt(sn), []string{domain}, nil) + if err != nil { + return nil + } + leaf := serverTls.GetTlsLeaf() + c.m[domain] = &leaf + return &leaf + } + // lookup and return wildcard cert if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok { if cert, ok := c.m[wildcardDomain]; ok { @@ -55,6 +94,11 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { } func (c *Certs) Compile() { + // don't bother compiling in self-signed mode + if c.ss { + return + } + // async compile magic go func() { // new map diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 1d51441..e7bf603 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -4,16 +4,22 @@ import ( "database/sql" _ "embed" "flag" + "fmt" "github.com/MrMelon54/violet/certs" "github.com/MrMelon54/violet/domains" errorPages "github.com/MrMelon54/violet/error-pages" "github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/proxy" + "github.com/MrMelon54/violet/router" "github.com/MrMelon54/violet/servers" "github.com/MrMelon54/violet/utils" _ "github.com/mattn/go-sqlite3" "log" + "net/http" "os" + "os/signal" + "syscall" + "time" ) // flags - each one has a usage field lol @@ -21,6 +27,7 @@ var ( databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") + selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode") errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages") apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") @@ -30,16 +37,21 @@ var ( func main() { log.Println("[Violet] Starting...") + flag.Parse() - // create path to cert dir - err := os.MkdirAll(*certPath, os.ModePerm) - if err != nil { - log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + if *certPath != "" { + // create path to cert dir + err := os.MkdirAll(*certPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) + } } - // create path to key dir - err = os.MkdirAll(*keyPath, os.ModePerm) - if err != nil { - log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) + if *keyPath != "" { + // create path to key dir + err := os.MkdirAll(*keyPath, os.ModePerm) + if err != nil { + log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) + } } // open sqlite database @@ -48,11 +60,12 @@ func main() { log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath) } - allowedDomains := domains.New(db) // load allowed domains - allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath)) // load certificate manager - reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy - dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider - dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider + allowedDomains := domains.New(db) // load allowed domains + allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager + reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy + dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider + dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider + dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager // struct containing config for the http servers srvConf := &servers.Conf{ @@ -65,16 +78,41 @@ func main() { Favicons: dynamicFavicons, Verify: nil, // TODO: add mjwt verify support ErrorPages: dynamicErrorPages, - Proxy: reverseProxy, + Router: dynamicRouter, } + var srvApi, srvHttp, srvHttps *http.Server if *apiListen != "" { - servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains}) + srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}) } if *httpListen != "" { - servers.NewHttpServer(srvConf) + srvHttp = servers.NewHttpServer(srvConf) } if *httpsListen != "" { - servers.NewHttpsServer(srvConf) + srvHttps = servers.NewHttpsServer(srvConf) } + + // Wait for exit signal + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill) + <-sc + fmt.Println() + + // Stop servers + log.Printf("[Violet] Stopping...") + n := time.Now() + + // close http servers + if srvApi != nil { + srvApi.Close() + } + if srvHttp != nil { + srvHttp.Close() + } + if srvHttps != nil { + srvHttps.Close() + } + + log.Printf("[Violet] Took '%s' to shutdown\n", time.Now().Sub(n)) + log.Println("[Violet] Goodbye") } diff --git a/domains/create-table-domains.sql b/domains/create-table-domains.sql new file mode 100644 index 0000000..129e224 --- /dev/null +++ b/domains/create-table-domains.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS domains +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT, + active INTEGER DEFAULT 1 +); diff --git a/domains/domains.go b/domains/domains.go index 5ae003b..9bac6e6 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -2,12 +2,16 @@ package domains import ( "database/sql" + _ "embed" "github.com/MrMelon54/violet/utils" "log" "strings" "sync" ) +//go:embed create-table-domains.sql +var createTableDomains string + // Domains is the domain list and management system. type Domains struct { db *sql.DB @@ -24,7 +28,7 @@ func New(db *sql.DB) *Domains { } // init domains table - _, err := a.db.Exec(`create table if not exists domains (id integer primary key autoincrement, domain varchar)`) + _, err := a.db.Exec(createTableDomains) if err != nil { log.Printf("[WARN] Failed to generate 'domains' table\n") return nil @@ -37,11 +41,7 @@ func New(db *sql.DB) *Domains { // IsValid returns true if a domain is valid. func (d *Domains) IsValid(host string) bool { - // remove the port - domain, ok := utils.GetDomainWithoutPort(host) - if !ok { - return false - } + domain, _, _ := utils.SplitDomainPort(host, 0) // read lock for safety d.s.RLock() @@ -88,7 +88,7 @@ func (d *Domains) internalCompile(m map[string]struct{}) error { log.Println("[Domains] Updating domains from database") // sql or something? - rows, err := d.db.Query("select name from domains where enabled = true") + rows, err := d.db.Query(`select domain from domains where active = 1`) if err != nil { return err } diff --git a/identifier.sqlite b/identifier.sqlite deleted file mode 100644 index 288d146..0000000 Binary files a/identifier.sqlite and /dev/null differ diff --git a/router/create-table-redirects.sql b/router/create-table-redirects.sql new file mode 100644 index 0000000..2ae16d1 --- /dev/null +++ b/router/create-table-redirects.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS redirects +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT, + pre INTEGER, + destination TEXT, + abs INTEGER, + code INTEGER, + active INTEGER DEFAULT 1 +); diff --git a/router/create-table-routes.sql b/router/create-table-routes.sql new file mode 100644 index 0000000..2b5fc38 --- /dev/null +++ b/router/create-table-routes.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS routes +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT, + pre INTEGER, + destination TEXT, + abs INTEGER, + cors INTEGER, + secure_mode INTEGER, + forward_host INTEGER, + forward_addr INTEGER, + ignore_cert INTEGER, + active INTEGER DEFAULT 1 +); diff --git a/router/manager.go b/router/manager.go new file mode 100644 index 0000000..0d8f7b8 --- /dev/null +++ b/router/manager.go @@ -0,0 +1,228 @@ +package router + +import ( + "database/sql" + _ "embed" + "fmt" + "github.com/MrMelon54/violet/target" + "github.com/MrMelon54/violet/utils" + "log" + "net/http" + "net/http/httputil" + "strings" + "sync" +) + +// Manager is a database and mutex wrap around router allowing it to be +// dynamically regenerated after updating the database of routes. +type Manager struct { + db *sql.DB + s *sync.RWMutex + r *Router + p *httputil.ReverseProxy +} + +var ( + //go:embed create-table-routes.sql + createTableRoutes string + //go:embed create-table-redirects.sql + createTableRedirects string + //go:embed query-table-routes.sql + queryTableRoutes string + //go:embed query-table-redirects.sql + queryTableRedirects string +) + +// NewManager create a new manager, initialises the routes and redirects tables +// in the database and runs a first time compile. +func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager { + m := &Manager{ + db: db, + s: &sync.RWMutex{}, + r: New(nil), + p: proxy, + } + + // init routes table + _, err := m.db.Exec(createTableRoutes) + if err != nil { + log.Printf("[WARN] Failed to generate 'routes' table\n") + return nil + } + + // init redirects table + _, err = m.db.Exec(createTableRedirects) + if err != nil { + log.Printf("[WARN] Failed to generate 'redirects' table\n") + return nil + } + + // run compile to get the initial router + m.Compile() + return m +} + +func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + m.s.RLock() + m.r.ServeHTTP(rw, req) + m.s.RUnlock() +} + +func (m *Manager) Compile() { + go func() { + // new router + router := New(m.p) + + // compile router and check errors + err := m.internalCompile(router) + if err != nil { + log.Printf("[Manager] Compile failed: %s\n", err) + return + } + + // lock while replacing router + m.s.Lock() + m.r = router + m.s.Unlock() + }() +} + +// internalCompile is a hidden internal method for querying the database during +// the Compile() method. +func (m *Manager) internalCompile(router *Router) error { + log.Println("[Manager] Updating routes from database") + + // sql or something? + rows, err := m.db.Query(queryTableRoutes) + if err != nil { + return err + } + defer rows.Close() + + // loop through rows and scan the options + for rows.Next() { + var ( + pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool + src, dst string + ) + err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert) + if err != nil { + return err + } + + err = addRoute(router, src, dst, target.Route{ + Pre: pre, + Abs: abs, + Cors: cors, + SecureMode: secure_mode, + ForwardHost: forward_host, + ForwardAddr: forward_addr, + IgnoreCert: ignore_cert, + }) + if err != nil { + return err + } + } + + // check for errors + if err := rows.Err(); err != nil { + return err + } + + // sql or something? + rows, err = m.db.Query(queryTableRedirects) + if err != nil { + return err + } + defer rows.Close() + + // loop through rows and scan the options + for rows.Next() { + var ( + pre, abs bool + code int + src, dst string + ) + err := rows.Scan(&src, &pre, &dst, &abs, &code) + if err != nil { + return err + } + + err = addRedirect(router, src, dst, target.Redirect{ + Pre: pre, + Abs: abs, + Code: code, + }) + if err != nil { + return err + } + } + + // check for errors + return rows.Err() +} + +// addRoute is an alias to parse the src and dst then add the route +func addRoute(router *Router, src string, dst string, t target.Route) error { + srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst) + if err != nil { + return err + } + + // update target route values and add route + t.Host = dstHost + t.Port = dstPort + t.Path = dstPath + router.AddRoute(srcHost, srcPath, t) + return nil +} + +// addRedirect is an alias to parse the src and dst then add the redirect +func addRedirect(router *Router, src string, dst string, t target.Redirect) error { + srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst) + if err != nil { + return err + } + + t.Host = dstHost + t.Port = dstPort + t.Path = dstPath + router.AddRedirect(srcHost, srcPath, t) + return nil +} + +// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values +func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) { + // check if source has path + var srcHost, srcPath string + nSrc := strings.IndexByte(src, '/') + if nSrc == -1 { + // set host then path to / + srcHost = src + srcPath = "/" + } else { + // set host then custom path + srcHost = src[:nSrc] + srcPath = src[nSrc:] + } + + // check if destination has path + var dstPath string + nDst := strings.IndexByte(dst, '/') + if nDst == -1 { + // set path to / + dstPath = "/" + } else { + // set custom path then trim dst string to the host + dstPath = dst[nDst:] + dst = dst[:nDst] + } + + // try to split the destination host into domain + port + dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0) + if !ok { + return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst) + } + + return srcHost, srcPath, dstHost, dstPort, dstPath, nil +} diff --git a/router/query-table-redirects.sql b/router/query-table-redirects.sql new file mode 100644 index 0000000..cb90280 --- /dev/null +++ b/router/query-table-redirects.sql @@ -0,0 +1,7 @@ +select source, + pre, + destination, + abs, + code +from redirects +where active = true diff --git a/router/query-table-routes.sql b/router/query-table-routes.sql new file mode 100644 index 0000000..107cd5c --- /dev/null +++ b/router/query-table-routes.sql @@ -0,0 +1,11 @@ +select source, + pre, + destination, + abs, + cors, + secure_mode, + forward_host, + forward_addr, + ignore_cert +from routes +where active = true diff --git a/router/router.go b/router/router.go index de9713e..b6e89c0 100644 --- a/router/router.go +++ b/router/router.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/MrMelon54/trie" "github.com/MrMelon54/violet/target" + "github.com/MrMelon54/violet/utils" "net/http" "net/http/httputil" "strings" @@ -59,7 +60,11 @@ func (r *Router) AddRedirect(host, path string, t target.Redirect) { } func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - host := req.Host + if req.URL.Path == "" { + req.URL.Path = "/" + } + + host, _, _ := utils.SplitDomainPort(req.Host, 0) if r.serveRedirectHTTP(rw, req, host) { return } diff --git a/router/router_test.go b/router/router_test.go index 03a68ae..10dadc1 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -123,8 +123,6 @@ func TestRouter_AddRedirect(t *testing.T) { u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v} if v == "" { u1 = nil - } else if v == "/" { - u1.Path = "" } u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k} assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2)) diff --git a/servers/conf.go b/servers/conf.go index 6479c86..5bf11a3 100644 --- a/servers/conf.go +++ b/servers/conf.go @@ -6,19 +6,20 @@ import ( "github.com/MrMelon54/violet/domains" errorPages "github.com/MrMelon54/violet/error-pages" "github.com/MrMelon54/violet/favicons" + "github.com/MrMelon54/violet/router" "github.com/mrmelon54/mjwt" - "net/http/httputil" ) +// Conf stores the shared configuration for the API, HTTP and HTTPS servers. type Conf struct { - ApiListen string - HttpListen string - HttpsListen string + ApiListen string // api server listen address + HttpListen string // http server listen address + HttpsListen string // https server listen address DB *sql.DB Domains *domains.Domains Certs *certs.Certs Favicons *favicons.Favicons Verify mjwt.Provider ErrorPages *errorPages.ErrorPages - Proxy *httputil.ReverseProxy + Router *router.Manager } diff --git a/servers/https.go b/servers/https.go index e21f83b..31b6afc 100644 --- a/servers/https.go +++ b/servers/https.go @@ -3,7 +3,6 @@ package servers import ( "crypto/tls" "fmt" - "github.com/MrMelon54/violet/router" "github.com/MrMelon54/violet/utils" "github.com/gorilla/mux" "github.com/sethvargo/go-limiter/httplimit" @@ -17,20 +16,9 @@ import ( // NewHttpsServer creates and runs a http server containing the public https // endpoints for the reverse proxy. func NewHttpsServer(conf *Conf) *http.Server { - r := router.New(conf.Proxy) - s := &http.Server{ - Addr: conf.HttpsListen, - Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("Content-Type", "text/html") - rw.WriteHeader(http.StatusNotImplemented) - _, _ = rw.Write([]byte("
"))
-			_, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req)))
-			_, _ = rw.Write([]byte("
")) - _ = r - // TODO: serve from router and proxy - // r.ServeHTTP(rw, req) - })), + Addr: conf.HttpsListen, + Handler: setupRateLimiter(300).Middleware(conf.Router), DisableGeneralOptionsHandler: false, TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { // error out on invalid domains diff --git a/target/redirect.go b/target/redirect.go index 5ed9fde..ac59246 100644 --- a/target/redirect.go +++ b/target/redirect.go @@ -6,17 +6,21 @@ import ( "net/http" "net/url" "path" + "strings" ) +// Redirect is a target used by the router to manage redirecting the request +// using the specified configuration. type Redirect struct { - Pre bool - Host string - Port int - Path string - Abs bool - Code int + Pre bool // if the path has had a prefix removed + Host string // target host + Port int // target port + Path string // target path (possibly a prefix or absolute) + Abs bool // if the path is a prefix or absolute + Code int // status code used to redirect } +// FullHost outputs a host:port combo or just the host if the port is 0. func (r Redirect) FullHost() string { if r.Port == 0 { return r.Host @@ -24,22 +28,42 @@ func (r Redirect) FullHost() string { return fmt.Sprintf("%s:%d", r.Host, r.Port) } +// ServeHTTP responds with the redirect to the response writer provided. func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // default to redirecting with StatusFound if code is not set + code := r.Code + if r.Code == 0 { + code = http.StatusFound + } + + // if not Abs then join with the ending of the current path p := r.Path if !r.Abs { p = path.Join(r.Path, req.URL.Path) + + // replace the trailing slash that path.Join() strips off + if strings.HasSuffix(req.URL.Path, "/") { + p += "/" + } } + + // fix empty path + if p == "" { + p = "/" + } + + // create a new URL u := &url.URL{ Scheme: req.URL.Scheme, Host: r.FullHost(), Path: p, } - if u.Path == "/" { - u.Path = "" - } - utils.FastRedirect(rw, req, u.String(), r.Code) + + // use fast redirect for speed + utils.FastRedirect(rw, req, u.String(), code) } +// String outputs a debug string for the redirect. func (r Redirect) String() string { return fmt.Sprintf("%#v", r) } diff --git a/target/route.go b/target/route.go index eaa8043..f0dfdf3 100644 --- a/target/route.go +++ b/target/route.go @@ -11,8 +11,10 @@ import ( "net/http" "net/url" "path" + "strings" ) +// serveApiCors outputs the cors headers to make APIs work. var serveApiCors = cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedHeaders: []string{"Content-Type", "Authorization"}, @@ -30,28 +32,35 @@ var serveApiCors = cors.New(cors.Options{ AllowCredentials: true, }) +// Route is a target used by the router to manage forwarding traffic to an +// internal server using the specified configuration. type Route struct { - Pre bool - Host string - Port int - Path string - Abs bool - Cors bool - SecureMode bool - ForwardHost bool - IgnoreCert bool - Headers http.Header - Proxy http.Handler + Pre bool // if the path has had a prefix removed + Host string // target host + Port int // target port + Path string // target path (possibly a prefix or absolute) + Abs bool // if the path is a prefix or absolute + Cors bool // add CORS headers + SecureMode bool // use HTTPS internally + ForwardHost bool // forward host header internally + ForwardAddr bool // forward remote address + IgnoreCert bool // ignore self-cert + Headers http.Header // extra headers + Proxy http.Handler // reverse proxy handler } +// IsIgnoreCert returns true if IgnoreCert is enabled. func (r Route) IsIgnoreCert() bool { return r.IgnoreCert } +// UpdateHeaders takes an existing set of headers and overwrites them with the +// extra headers. func (r Route) UpdateHeaders(header http.Header) { for k, v := range r.Headers { header[k] = v } } +// FullHost outputs a host:port combo or just the host if the port is 0. func (r Route) FullHost() string { if r.Port == 0 { return r.Host @@ -59,15 +68,21 @@ func (r Route) FullHost() string { return fmt.Sprintf("%s:%d", r.Host, r.Port) } +// ServeHTTP responds with the data proxied from the internal server to the +// response writer provided. func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if r.Cors { + // wraps with CORS handler serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req) } else { r.internalServeHTTP(rw, req) } } +// internalServeHTTP is an internal method which handles configuring the request +// for the reverse proxy handler. func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { + // set the scheme and port using defaults if the port is 0 scheme := "http" if r.SecureMode { scheme = "https" @@ -80,40 +95,76 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { } } + // if not Abs then join with the ending of the current path p := r.Path if !r.Abs { p = path.Join(r.Path, req.URL.Path) + + // replace the trailing slash that path.Join() strips off + if strings.HasSuffix(req.URL.Path, "/") { + p += "/" + } } + // fix empty path if p == "" { p = "/" } + // TODO: don't just copy the body into a buffer as this is really slow buf := new(bytes.Buffer) if req.Body != nil { _, _ = io.Copy(buf, req.Body) } + // create a new URL u := &url.URL{ Scheme: scheme, Host: r.FullHost(), Path: p, RawQuery: req.URL.RawQuery, } + + // create the internal request req2, err := http.NewRequest(req.Method, u.String(), buf) if err != nil { log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err) utils.RespondHttpStatus(rw, http.StatusBadGateway) return } + + // loops over the incoming request headers for k, v := range req.Header { + // ignore host header if k == "Host" { continue } + // copy header into the internal request req2.Header[k] = v } + + // if extra route headers are set + if r.Headers != nil { + // loop over headers + for k, v := range r.Headers { + // copy header into the internal request + req2.Header[k] = v + } + } + + // if forward host is enabled then send the host if r.ForwardHost { req2.Host = req.Host } + if r.ForwardAddr { + req2.Header.Add("X-Forwarded-For", req.RemoteAddr) + } + + // serve request with reverse proxy r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r)) } + +// String outputs a debug string for the route. +func (r Route) String() string { + return fmt.Sprintf("%#v", r) +} diff --git a/target/route_test.go b/target/route_test.go new file mode 100644 index 0000000..a1598c9 --- /dev/null +++ b/target/route_test.go @@ -0,0 +1,75 @@ +package target + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +type proxyTester struct { + got bool + rw http.ResponseWriter + req *http.Request +} + +func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + p.got = true + p.rw = rw + p.req = req +} + +func TestRoute_FullHost(t *testing.T) { + assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost()) + assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost()) +} + +func TestRoute_ServeHTTP(t *testing.T) { + a := []struct { + Route + target string + }{ + {Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"}, + {Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"}, + {Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"}, + {Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"}, + {Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"}, + {Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"}, + } + for _, i := range a { + pt := &proxyTester{} + i.Proxy = pt + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil) + i.ServeHTTP(res, req) + + assert.True(t, pt.got) + assert.Equal(t, i.target, pt.req.URL.String()) + if i.ForwardAddr { + assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For")) + } + if i.ForwardHost { + assert.Equal(t, req.Host, pt.req.Host) + } + if i.Headers != nil { + assert.Equal(t, i.Headers, pt.req.Header) + } + } +} + +func TestRoute_ServeHTTP_Cors(t *testing.T) { + pt := &proxyTester{} + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil) + req.Header.Set("Origin", "https://test.example.com") + i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt} + i.ServeHTTP(res, req) + + assert.True(t, pt.got) + assert.Equal(t, http.MethodOptions, pt.req.Method) + assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String()) + assert.Equal(t, "Origin", res.Header().Get("Vary")) + assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "Origin", res.Header().Get("Vary")) +} diff --git a/utils/response.go b/utils/response.go index cfdabc6..cf59f42 100644 --- a/utils/response.go +++ b/utils/response.go @@ -5,6 +5,7 @@ import ( "net/http" ) +// RespondHttpStatus outputs the status code and text using http.Error() func RespondHttpStatus(rw http.ResponseWriter, status int) { http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status) } diff --git a/utils/server-utils.go b/utils/server-utils.go index e390693..41eac5a 100644 --- a/utils/server-utils.go +++ b/utils/server-utils.go @@ -6,6 +6,8 @@ import ( "strings" ) +// logHttpServerError is the internal function powering the logging in +// RunBackgroundHttp and RunBackgroundHttps. func logHttpServerError(prefix string, err error) { if err != nil { if err == http.ErrServerClosed { @@ -16,10 +18,14 @@ func logHttpServerError(prefix string, err error) { } } +// RunBackgroundHttp runs a http server and logs when the server closes or +// errors. func RunBackgroundHttp(prefix string, s *http.Server) { logHttpServerError(prefix, s.ListenAndServe()) } +// RunBackgroundHttps runs a http server with TLS encryption and logs when the +// server closes or errors. func RunBackgroundHttps(prefix string, s *http.Server) { logHttpServerError(prefix, s.ListenAndServeTLS("", "")) }