Add Ctrl+C handling, self-signed mode for devs and fix some bugs in routing

This commit is contained in:
Melon 2023-04-24 15:36:21 +01:00
parent 0551e15979
commit 9147a813cb
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
21 changed files with 581 additions and 73 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.sqlite

View File

@ -1,11 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true"> <component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="identifier.sqlite" uuid="a1c751d4-a71e-4c87-b033-ea49e424ae9a"> <data-source source="LOCAL" name="__db.sqlite" uuid="5aeb4e88-8ec4-4227-a921-ba4eaed357bf">
<driver-ref>sqlite.xerial</driver-ref> <driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize> <synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver> <jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url> <jdbc-url>jdbc:sqlite:__db.sqlite</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir> <working-dir>$ProjectFileDir$</working-dir>
</data-source> </data-source>
</component> </component>

View File

@ -3,34 +3,55 @@ package certs
import ( import (
"code.mrmelon54.com/melon/certgen" "code.mrmelon54.com/melon/certgen"
"crypto/tls" "crypto/tls"
"crypto/x509/pkix"
"fmt" "fmt"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"io/fs" "io/fs"
"log" "log"
"math/big"
"path/filepath" "path/filepath"
"sync" "sync"
"sync/atomic"
"time"
) )
// Certs is the certificate loader and management system. // Certs is the certificate loader and management system.
type Certs struct { type Certs struct {
cDir fs.FS cDir fs.FS
kDir fs.FS kDir fs.FS
ss bool
s *sync.RWMutex s *sync.RWMutex
m map[string]*tls.Certificate m map[string]*tls.Certificate
ca *certgen.CertGen
sn atomic.Int64
} }
// New creates a new cert list // New creates a new cert list
func New(certDir fs.FS, keyDir fs.FS) *Certs { func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
a := &Certs{ c := &Certs{
cDir: certDir, cDir: certDir,
kDir: keyDir, kDir: keyDir,
ss: selfCert,
s: &sync.RWMutex{}, s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate), m: make(map[string]*tls.Certificate),
} }
if c.ss {
ca, err := certgen.MakeCaTls(pkix.Name{
Country: []string{"GB"},
Organization: []string{"Violet"},
OrganizationalUnit: []string{"Development"},
SerialNumber: "0",
CommonName: fmt.Sprintf("%d.violet.test", time.Now().Unix()),
}, big.NewInt(0))
if err != nil {
log.Fatalln("Failed to generate CA cert for self-signed mode:", err)
}
c.ca = ca
}
// run compile to get the initial data // run compile to get the initial data
a.Compile() c.Compile()
return a return c
} }
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate { func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
@ -43,6 +64,24 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
return cert return cert
} }
// if self-signed certificate is enabled then generate a certificate
if c.ss {
sn := c.sn.Add(1)
serverTls, err := certgen.MakeServerTls(c.ca, pkix.Name{
Country: []string{"GB"},
Organization: []string{domain},
OrganizationalUnit: []string{domain},
SerialNumber: fmt.Sprintf("%d", sn),
CommonName: domain,
}, big.NewInt(sn), []string{domain}, nil)
if err != nil {
return nil
}
leaf := serverTls.GetTlsLeaf()
c.m[domain] = &leaf
return &leaf
}
// lookup and return wildcard cert // lookup and return wildcard cert
if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok { if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok {
if cert, ok := c.m[wildcardDomain]; ok { if cert, ok := c.m[wildcardDomain]; ok {
@ -55,6 +94,11 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
} }
func (c *Certs) Compile() { func (c *Certs) Compile() {
// don't bother compiling in self-signed mode
if c.ss {
return
}
// async compile magic // async compile magic
go func() { go func() {
// new map // new map

View File

@ -4,16 +4,22 @@ import (
"database/sql" "database/sql"
_ "embed" _ "embed"
"flag" "flag"
"fmt"
"github.com/MrMelon54/violet/certs" "github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/domains"
errorPages "github.com/MrMelon54/violet/error-pages" errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers" "github.com/MrMelon54/violet/servers"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"log" "log"
"net/http"
"os" "os"
"os/signal"
"syscall"
"time"
) )
// flags - each one has a usage field lol // flags - each one has a usage field lol
@ -21,6 +27,7 @@ var (
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file") databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode")
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages") errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
@ -30,17 +37,22 @@ var (
func main() { func main() {
log.Println("[Violet] Starting...") log.Println("[Violet] Starting...")
flag.Parse()
if *certPath != "" {
// create path to cert dir // create path to cert dir
err := os.MkdirAll(*certPath, os.ModePerm) err := os.MkdirAll(*certPath, os.ModePerm)
if err != nil { if err != nil {
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath) log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
} }
}
if *keyPath != "" {
// create path to key dir // create path to key dir
err = os.MkdirAll(*keyPath, os.ModePerm) err := os.MkdirAll(*keyPath, os.ModePerm)
if err != nil { if err != nil {
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath) log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
} }
}
// open sqlite database // open sqlite database
db, err := sql.Open("sqlite3", *databasePath) db, err := sql.Open("sqlite3", *databasePath)
@ -49,10 +61,11 @@ func main() {
} }
allowedDomains := domains.New(db) // load allowed domains allowedDomains := domains.New(db) // load allowed domains
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath)) // load certificate manager allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager
// struct containing config for the http servers // struct containing config for the http servers
srvConf := &servers.Conf{ srvConf := &servers.Conf{
@ -65,16 +78,41 @@ func main() {
Favicons: dynamicFavicons, Favicons: dynamicFavicons,
Verify: nil, // TODO: add mjwt verify support Verify: nil, // TODO: add mjwt verify support
ErrorPages: dynamicErrorPages, ErrorPages: dynamicErrorPages,
Proxy: reverseProxy, Router: dynamicRouter,
} }
var srvApi, srvHttp, srvHttps *http.Server
if *apiListen != "" { if *apiListen != "" {
servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains}) srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter})
} }
if *httpListen != "" { if *httpListen != "" {
servers.NewHttpServer(srvConf) srvHttp = servers.NewHttpServer(srvConf)
} }
if *httpsListen != "" { if *httpsListen != "" {
servers.NewHttpsServer(srvConf) srvHttps = servers.NewHttpsServer(srvConf)
} }
// Wait for exit signal
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill)
<-sc
fmt.Println()
// Stop servers
log.Printf("[Violet] Stopping...")
n := time.Now()
// close http servers
if srvApi != nil {
srvApi.Close()
}
if srvHttp != nil {
srvHttp.Close()
}
if srvHttps != nil {
srvHttps.Close()
}
log.Printf("[Violet] Took '%s' to shutdown\n", time.Now().Sub(n))
log.Println("[Violet] Goodbye")
} }

View File

@ -0,0 +1,6 @@
CREATE TABLE IF NOT EXISTS domains
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain TEXT,
active INTEGER DEFAULT 1
);

View File

@ -2,12 +2,16 @@ package domains
import ( import (
"database/sql" "database/sql"
_ "embed"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"log" "log"
"strings" "strings"
"sync" "sync"
) )
//go:embed create-table-domains.sql
var createTableDomains string
// Domains is the domain list and management system. // Domains is the domain list and management system.
type Domains struct { type Domains struct {
db *sql.DB db *sql.DB
@ -24,7 +28,7 @@ func New(db *sql.DB) *Domains {
} }
// init domains table // init domains table
_, err := a.db.Exec(`create table if not exists domains (id integer primary key autoincrement, domain varchar)`) _, err := a.db.Exec(createTableDomains)
if err != nil { if err != nil {
log.Printf("[WARN] Failed to generate 'domains' table\n") log.Printf("[WARN] Failed to generate 'domains' table\n")
return nil return nil
@ -37,11 +41,7 @@ func New(db *sql.DB) *Domains {
// IsValid returns true if a domain is valid. // IsValid returns true if a domain is valid.
func (d *Domains) IsValid(host string) bool { func (d *Domains) IsValid(host string) bool {
// remove the port domain, _, _ := utils.SplitDomainPort(host, 0)
domain, ok := utils.GetDomainWithoutPort(host)
if !ok {
return false
}
// read lock for safety // read lock for safety
d.s.RLock() d.s.RLock()
@ -88,7 +88,7 @@ func (d *Domains) internalCompile(m map[string]struct{}) error {
log.Println("[Domains] Updating domains from database") log.Println("[Domains] Updating domains from database")
// sql or something? // sql or something?
rows, err := d.db.Query("select name from domains where enabled = true") rows, err := d.db.Query(`select domain from domains where active = 1`)
if err != nil { if err != nil {
return err return err
} }

Binary file not shown.

View File

@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT,
pre INTEGER,
destination TEXT,
abs INTEGER,
code INTEGER,
active INTEGER DEFAULT 1
);

View File

@ -0,0 +1,14 @@
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT,
pre INTEGER,
destination TEXT,
abs INTEGER,
cors INTEGER,
secure_mode INTEGER,
forward_host INTEGER,
forward_addr INTEGER,
ignore_cert INTEGER,
active INTEGER DEFAULT 1
);

228
router/manager.go Normal file
View File

@ -0,0 +1,228 @@
package router
import (
"database/sql"
_ "embed"
"fmt"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"log"
"net/http"
"net/http/httputil"
"strings"
"sync"
)
// Manager is a database and mutex wrap around router allowing it to be
// dynamically regenerated after updating the database of routes.
type Manager struct {
db *sql.DB
s *sync.RWMutex
r *Router
p *httputil.ReverseProxy
}
var (
//go:embed create-table-routes.sql
createTableRoutes string
//go:embed create-table-redirects.sql
createTableRedirects string
//go:embed query-table-routes.sql
queryTableRoutes string
//go:embed query-table-redirects.sql
queryTableRedirects string
)
// NewManager create a new manager, initialises the routes and redirects tables
// in the database and runs a first time compile.
func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager {
m := &Manager{
db: db,
s: &sync.RWMutex{},
r: New(nil),
p: proxy,
}
// init routes table
_, err := m.db.Exec(createTableRoutes)
if err != nil {
log.Printf("[WARN] Failed to generate 'routes' table\n")
return nil
}
// init redirects table
_, err = m.db.Exec(createTableRedirects)
if err != nil {
log.Printf("[WARN] Failed to generate 'redirects' table\n")
return nil
}
// run compile to get the initial router
m.Compile()
return m
}
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
m.s.RLock()
m.r.ServeHTTP(rw, req)
m.s.RUnlock()
}
func (m *Manager) Compile() {
go func() {
// new router
router := New(m.p)
// compile router and check errors
err := m.internalCompile(router)
if err != nil {
log.Printf("[Manager] Compile failed: %s\n", err)
return
}
// lock while replacing router
m.s.Lock()
m.r = router
m.s.Unlock()
}()
}
// internalCompile is a hidden internal method for querying the database during
// the Compile() method.
func (m *Manager) internalCompile(router *Router) error {
log.Println("[Manager] Updating routes from database")
// sql or something?
rows, err := m.db.Query(queryTableRoutes)
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the options
for rows.Next() {
var (
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
src, dst string
)
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
if err != nil {
return err
}
err = addRoute(router, src, dst, target.Route{
Pre: pre,
Abs: abs,
Cors: cors,
SecureMode: secure_mode,
ForwardHost: forward_host,
ForwardAddr: forward_addr,
IgnoreCert: ignore_cert,
})
if err != nil {
return err
}
}
// check for errors
if err := rows.Err(); err != nil {
return err
}
// sql or something?
rows, err = m.db.Query(queryTableRedirects)
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the options
for rows.Next() {
var (
pre, abs bool
code int
src, dst string
)
err := rows.Scan(&src, &pre, &dst, &abs, &code)
if err != nil {
return err
}
err = addRedirect(router, src, dst, target.Redirect{
Pre: pre,
Abs: abs,
Code: code,
})
if err != nil {
return err
}
}
// check for errors
return rows.Err()
}
// addRoute is an alias to parse the src and dst then add the route
func addRoute(router *Router, src string, dst string, t target.Route) error {
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
if err != nil {
return err
}
// update target route values and add route
t.Host = dstHost
t.Port = dstPort
t.Path = dstPath
router.AddRoute(srcHost, srcPath, t)
return nil
}
// addRedirect is an alias to parse the src and dst then add the redirect
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
if err != nil {
return err
}
t.Host = dstHost
t.Port = dstPort
t.Path = dstPath
router.AddRedirect(srcHost, srcPath, t)
return nil
}
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
// check if source has path
var srcHost, srcPath string
nSrc := strings.IndexByte(src, '/')
if nSrc == -1 {
// set host then path to /
srcHost = src
srcPath = "/"
} else {
// set host then custom path
srcHost = src[:nSrc]
srcPath = src[nSrc:]
}
// check if destination has path
var dstPath string
nDst := strings.IndexByte(dst, '/')
if nDst == -1 {
// set path to /
dstPath = "/"
} else {
// set custom path then trim dst string to the host
dstPath = dst[nDst:]
dst = dst[:nDst]
}
// try to split the destination host into domain + port
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
if !ok {
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
}
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
}

View File

@ -0,0 +1,7 @@
select source,
pre,
destination,
abs,
code
from redirects
where active = true

View File

@ -0,0 +1,11 @@
select source,
pre,
destination,
abs,
cors,
secure_mode,
forward_host,
forward_addr,
ignore_cert
from routes
where active = true

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/MrMelon54/trie" "github.com/MrMelon54/trie"
"github.com/MrMelon54/violet/target" "github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strings" "strings"
@ -59,7 +60,11 @@ func (r *Router) AddRedirect(host, path string, t target.Redirect) {
} }
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
host := req.Host if req.URL.Path == "" {
req.URL.Path = "/"
}
host, _, _ := utils.SplitDomainPort(req.Host, 0)
if r.serveRedirectHTTP(rw, req, host) { if r.serveRedirectHTTP(rw, req, host) {
return return
} }

View File

@ -123,8 +123,6 @@ func TestRouter_AddRedirect(t *testing.T) {
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v} u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
if v == "" { if v == "" {
u1 = nil u1 = nil
} else if v == "/" {
u1.Path = ""
} }
u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k} u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k}
assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2)) assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2))

View File

@ -6,19 +6,20 @@ import (
"github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/domains"
errorPages "github.com/MrMelon54/violet/error-pages" errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/router"
"github.com/mrmelon54/mjwt" "github.com/mrmelon54/mjwt"
"net/http/httputil"
) )
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
type Conf struct { type Conf struct {
ApiListen string ApiListen string // api server listen address
HttpListen string HttpListen string // http server listen address
HttpsListen string HttpsListen string // https server listen address
DB *sql.DB DB *sql.DB
Domains *domains.Domains Domains *domains.Domains
Certs *certs.Certs Certs *certs.Certs
Favicons *favicons.Favicons Favicons *favicons.Favicons
Verify mjwt.Provider Verify mjwt.Provider
ErrorPages *errorPages.ErrorPages ErrorPages *errorPages.ErrorPages
Proxy *httputil.ReverseProxy Router *router.Manager
} }

View File

@ -3,7 +3,6 @@ package servers
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/sethvargo/go-limiter/httplimit" "github.com/sethvargo/go-limiter/httplimit"
@ -17,20 +16,9 @@ import (
// NewHttpsServer creates and runs a http server containing the public https // NewHttpsServer creates and runs a http server containing the public https
// endpoints for the reverse proxy. // endpoints for the reverse proxy.
func NewHttpsServer(conf *Conf) *http.Server { func NewHttpsServer(conf *Conf) *http.Server {
r := router.New(conf.Proxy)
s := &http.Server{ s := &http.Server{
Addr: conf.HttpsListen, Addr: conf.HttpsListen,
Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { Handler: setupRateLimiter(300).Middleware(conf.Router),
rw.Header().Set("Content-Type", "text/html")
rw.WriteHeader(http.StatusNotImplemented)
_, _ = rw.Write([]byte("<pre>"))
_, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req)))
_, _ = rw.Write([]byte("</pre>"))
_ = r
// TODO: serve from router and proxy
// r.ServeHTTP(rw, req)
})),
DisableGeneralOptionsHandler: false, DisableGeneralOptionsHandler: false,
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains // error out on invalid domains

View File

@ -6,17 +6,21 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strings"
) )
// Redirect is a target used by the router to manage redirecting the request
// using the specified configuration.
type Redirect struct { type Redirect struct {
Pre bool Pre bool // if the path has had a prefix removed
Host string Host string // target host
Port int Port int // target port
Path string Path string // target path (possibly a prefix or absolute)
Abs bool Abs bool // if the path is a prefix or absolute
Code int Code int // status code used to redirect
} }
// FullHost outputs a host:port combo or just the host if the port is 0.
func (r Redirect) FullHost() string { func (r Redirect) FullHost() string {
if r.Port == 0 { if r.Port == 0 {
return r.Host return r.Host
@ -24,22 +28,42 @@ func (r Redirect) FullHost() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port) return fmt.Sprintf("%s:%d", r.Host, r.Port)
} }
// ServeHTTP responds with the redirect to the response writer provided.
func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// default to redirecting with StatusFound if code is not set
code := r.Code
if r.Code == 0 {
code = http.StatusFound
}
// if not Abs then join with the ending of the current path
p := r.Path p := r.Path
if !r.Abs { if !r.Abs {
p = path.Join(r.Path, req.URL.Path) p = path.Join(r.Path, req.URL.Path)
// replace the trailing slash that path.Join() strips off
if strings.HasSuffix(req.URL.Path, "/") {
p += "/"
} }
}
// fix empty path
if p == "" {
p = "/"
}
// create a new URL
u := &url.URL{ u := &url.URL{
Scheme: req.URL.Scheme, Scheme: req.URL.Scheme,
Host: r.FullHost(), Host: r.FullHost(),
Path: p, Path: p,
} }
if u.Path == "/" {
u.Path = "" // use fast redirect for speed
} utils.FastRedirect(rw, req, u.String(), code)
utils.FastRedirect(rw, req, u.String(), r.Code)
} }
// String outputs a debug string for the redirect.
func (r Redirect) String() string { func (r Redirect) String() string {
return fmt.Sprintf("%#v", r) return fmt.Sprintf("%#v", r)
} }

View File

@ -11,8 +11,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strings"
) )
// serveApiCors outputs the cors headers to make APIs work.
var serveApiCors = cors.New(cors.Options{ var serveApiCors = cors.New(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"Content-Type", "Authorization"}, AllowedHeaders: []string{"Content-Type", "Authorization"},
@ -30,28 +32,35 @@ var serveApiCors = cors.New(cors.Options{
AllowCredentials: true, AllowCredentials: true,
}) })
// Route is a target used by the router to manage forwarding traffic to an
// internal server using the specified configuration.
type Route struct { type Route struct {
Pre bool Pre bool // if the path has had a prefix removed
Host string Host string // target host
Port int Port int // target port
Path string Path string // target path (possibly a prefix or absolute)
Abs bool Abs bool // if the path is a prefix or absolute
Cors bool Cors bool // add CORS headers
SecureMode bool SecureMode bool // use HTTPS internally
ForwardHost bool ForwardHost bool // forward host header internally
IgnoreCert bool ForwardAddr bool // forward remote address
Headers http.Header IgnoreCert bool // ignore self-cert
Proxy http.Handler Headers http.Header // extra headers
Proxy http.Handler // reverse proxy handler
} }
// IsIgnoreCert returns true if IgnoreCert is enabled.
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert } func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
// UpdateHeaders takes an existing set of headers and overwrites them with the
// extra headers.
func (r Route) UpdateHeaders(header http.Header) { func (r Route) UpdateHeaders(header http.Header) {
for k, v := range r.Headers { for k, v := range r.Headers {
header[k] = v header[k] = v
} }
} }
// FullHost outputs a host:port combo or just the host if the port is 0.
func (r Route) FullHost() string { func (r Route) FullHost() string {
if r.Port == 0 { if r.Port == 0 {
return r.Host return r.Host
@ -59,15 +68,21 @@ func (r Route) FullHost() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port) return fmt.Sprintf("%s:%d", r.Host, r.Port)
} }
// ServeHTTP responds with the data proxied from the internal server to the
// response writer provided.
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if r.Cors { if r.Cors {
// wraps with CORS handler
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req) serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
} else { } else {
r.internalServeHTTP(rw, req) r.internalServeHTTP(rw, req)
} }
} }
// internalServeHTTP is an internal method which handles configuring the request
// for the reverse proxy handler.
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// set the scheme and port using defaults if the port is 0
scheme := "http" scheme := "http"
if r.SecureMode { if r.SecureMode {
scheme = "https" scheme = "https"
@ -80,40 +95,76 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
} }
// if not Abs then join with the ending of the current path
p := r.Path p := r.Path
if !r.Abs { if !r.Abs {
p = path.Join(r.Path, req.URL.Path) p = path.Join(r.Path, req.URL.Path)
// replace the trailing slash that path.Join() strips off
if strings.HasSuffix(req.URL.Path, "/") {
p += "/"
}
} }
// fix empty path
if p == "" { if p == "" {
p = "/" p = "/"
} }
// TODO: don't just copy the body into a buffer as this is really slow
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if req.Body != nil { if req.Body != nil {
_, _ = io.Copy(buf, req.Body) _, _ = io.Copy(buf, req.Body)
} }
// create a new URL
u := &url.URL{ u := &url.URL{
Scheme: scheme, Scheme: scheme,
Host: r.FullHost(), Host: r.FullHost(),
Path: p, Path: p,
RawQuery: req.URL.RawQuery, RawQuery: req.URL.RawQuery,
} }
// create the internal request
req2, err := http.NewRequest(req.Method, u.String(), buf) req2, err := http.NewRequest(req.Method, u.String(), buf)
if err != nil { if err != nil {
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err) log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
utils.RespondHttpStatus(rw, http.StatusBadGateway) utils.RespondHttpStatus(rw, http.StatusBadGateway)
return return
} }
// loops over the incoming request headers
for k, v := range req.Header { for k, v := range req.Header {
// ignore host header
if k == "Host" { if k == "Host" {
continue continue
} }
// copy header into the internal request
req2.Header[k] = v req2.Header[k] = v
} }
// if extra route headers are set
if r.Headers != nil {
// loop over headers
for k, v := range r.Headers {
// copy header into the internal request
req2.Header[k] = v
}
}
// if forward host is enabled then send the host
if r.ForwardHost { if r.ForwardHost {
req2.Host = req.Host req2.Host = req.Host
} }
if r.ForwardAddr {
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
}
// serve request with reverse proxy
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r)) r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
} }
// String outputs a debug string for the route.
func (r Route) String() string {
return fmt.Sprintf("%#v", r)
}

75
target/route_test.go Normal file
View File

@ -0,0 +1,75 @@
package target
import (
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
type proxyTester struct {
got bool
rw http.ResponseWriter
req *http.Request
}
func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.got = true
p.rw = rw
p.req = req
}
func TestRoute_FullHost(t *testing.T) {
assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost())
assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost())
}
func TestRoute_ServeHTTP(t *testing.T) {
a := []struct {
Route
target string
}{
{Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"},
{Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"},
{Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"},
{Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"},
{Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"},
{Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"},
}
for _, i := range a {
pt := &proxyTester{}
i.Proxy = pt
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
i.ServeHTTP(res, req)
assert.True(t, pt.got)
assert.Equal(t, i.target, pt.req.URL.String())
if i.ForwardAddr {
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
}
if i.ForwardHost {
assert.Equal(t, req.Host, pt.req.Host)
}
if i.Headers != nil {
assert.Equal(t, i.Headers, pt.req.Header)
}
}
}
func TestRoute_ServeHTTP_Cors(t *testing.T) {
pt := &proxyTester{}
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
req.Header.Set("Origin", "https://test.example.com")
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt}
i.ServeHTTP(res, req)
assert.True(t, pt.got)
assert.Equal(t, http.MethodOptions, pt.req.Method)
assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String())
assert.Equal(t, "Origin", res.Header().Get("Vary"))
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Origin", res.Header().Get("Vary"))
}

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
) )
// RespondHttpStatus outputs the status code and text using http.Error()
func RespondHttpStatus(rw http.ResponseWriter, status int) { func RespondHttpStatus(rw http.ResponseWriter, status int) {
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status) http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
} }

View File

@ -6,6 +6,8 @@ import (
"strings" "strings"
) )
// logHttpServerError is the internal function powering the logging in
// RunBackgroundHttp and RunBackgroundHttps.
func logHttpServerError(prefix string, err error) { func logHttpServerError(prefix string, err error) {
if err != nil { if err != nil {
if err == http.ErrServerClosed { if err == http.ErrServerClosed {
@ -16,10 +18,14 @@ func logHttpServerError(prefix string, err error) {
} }
} }
// RunBackgroundHttp runs a http server and logs when the server closes or
// errors.
func RunBackgroundHttp(prefix string, s *http.Server) { func RunBackgroundHttp(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServe()) logHttpServerError(prefix, s.ListenAndServe())
} }
// RunBackgroundHttps runs a http server with TLS encryption and logs when the
// server closes or errors.
func RunBackgroundHttps(prefix string, s *http.Server) { func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", "")) logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
} }