mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 19:01:39 +00:00
Add Ctrl+C handling, self-signed mode for devs and fix some bugs in routing
This commit is contained in:
parent
0551e15979
commit
9147a813cb
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
*.sqlite
|
@ -1,11 +1,11 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
||||||
<data-source source="LOCAL" name="identifier.sqlite" uuid="a1c751d4-a71e-4c87-b033-ea49e424ae9a">
|
<data-source source="LOCAL" name="__db.sqlite" uuid="5aeb4e88-8ec4-4227-a921-ba4eaed357bf">
|
||||||
<driver-ref>sqlite.xerial</driver-ref>
|
<driver-ref>sqlite.xerial</driver-ref>
|
||||||
<synchronize>true</synchronize>
|
<synchronize>true</synchronize>
|
||||||
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
||||||
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
|
<jdbc-url>jdbc:sqlite:__db.sqlite</jdbc-url>
|
||||||
<working-dir>$ProjectFileDir$</working-dir>
|
<working-dir>$ProjectFileDir$</working-dir>
|
||||||
</data-source>
|
</data-source>
|
||||||
</component>
|
</component>
|
||||||
|
@ -3,34 +3,55 @@ package certs
|
|||||||
import (
|
import (
|
||||||
"code.mrmelon54.com/melon/certgen"
|
"code.mrmelon54.com/melon/certgen"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509/pkix"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log"
|
"log"
|
||||||
|
"math/big"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Certs is the certificate loader and management system.
|
// Certs is the certificate loader and management system.
|
||||||
type Certs struct {
|
type Certs struct {
|
||||||
cDir fs.FS
|
cDir fs.FS
|
||||||
kDir fs.FS
|
kDir fs.FS
|
||||||
|
ss bool
|
||||||
s *sync.RWMutex
|
s *sync.RWMutex
|
||||||
m map[string]*tls.Certificate
|
m map[string]*tls.Certificate
|
||||||
|
ca *certgen.CertGen
|
||||||
|
sn atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new cert list
|
// New creates a new cert list
|
||||||
func New(certDir fs.FS, keyDir fs.FS) *Certs {
|
func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
|
||||||
a := &Certs{
|
c := &Certs{
|
||||||
cDir: certDir,
|
cDir: certDir,
|
||||||
kDir: keyDir,
|
kDir: keyDir,
|
||||||
|
ss: selfCert,
|
||||||
s: &sync.RWMutex{},
|
s: &sync.RWMutex{},
|
||||||
m: make(map[string]*tls.Certificate),
|
m: make(map[string]*tls.Certificate),
|
||||||
}
|
}
|
||||||
|
if c.ss {
|
||||||
|
ca, err := certgen.MakeCaTls(pkix.Name{
|
||||||
|
Country: []string{"GB"},
|
||||||
|
Organization: []string{"Violet"},
|
||||||
|
OrganizationalUnit: []string{"Development"},
|
||||||
|
SerialNumber: "0",
|
||||||
|
CommonName: fmt.Sprintf("%d.violet.test", time.Now().Unix()),
|
||||||
|
}, big.NewInt(0))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Failed to generate CA cert for self-signed mode:", err)
|
||||||
|
}
|
||||||
|
c.ca = ca
|
||||||
|
}
|
||||||
|
|
||||||
// run compile to get the initial data
|
// run compile to get the initial data
|
||||||
a.Compile()
|
c.Compile()
|
||||||
return a
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
||||||
@ -43,6 +64,24 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
|||||||
return cert
|
return cert
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if self-signed certificate is enabled then generate a certificate
|
||||||
|
if c.ss {
|
||||||
|
sn := c.sn.Add(1)
|
||||||
|
serverTls, err := certgen.MakeServerTls(c.ca, pkix.Name{
|
||||||
|
Country: []string{"GB"},
|
||||||
|
Organization: []string{domain},
|
||||||
|
OrganizationalUnit: []string{domain},
|
||||||
|
SerialNumber: fmt.Sprintf("%d", sn),
|
||||||
|
CommonName: domain,
|
||||||
|
}, big.NewInt(sn), []string{domain}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
leaf := serverTls.GetTlsLeaf()
|
||||||
|
c.m[domain] = &leaf
|
||||||
|
return &leaf
|
||||||
|
}
|
||||||
|
|
||||||
// lookup and return wildcard cert
|
// lookup and return wildcard cert
|
||||||
if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok {
|
if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok {
|
||||||
if cert, ok := c.m[wildcardDomain]; ok {
|
if cert, ok := c.m[wildcardDomain]; ok {
|
||||||
@ -55,6 +94,11 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Certs) Compile() {
|
func (c *Certs) Compile() {
|
||||||
|
// don't bother compiling in self-signed mode
|
||||||
|
if c.ss {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// async compile magic
|
// async compile magic
|
||||||
go func() {
|
go func() {
|
||||||
// new map
|
// new map
|
||||||
|
@ -4,16 +4,22 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"github.com/MrMelon54/violet/certs"
|
"github.com/MrMelon54/violet/certs"
|
||||||
"github.com/MrMelon54/violet/domains"
|
"github.com/MrMelon54/violet/domains"
|
||||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||||
"github.com/MrMelon54/violet/favicons"
|
"github.com/MrMelon54/violet/favicons"
|
||||||
"github.com/MrMelon54/violet/proxy"
|
"github.com/MrMelon54/violet/proxy"
|
||||||
|
"github.com/MrMelon54/violet/router"
|
||||||
"github.com/MrMelon54/violet/servers"
|
"github.com/MrMelon54/violet/servers"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// flags - each one has a usage field lol
|
// flags - each one has a usage field lol
|
||||||
@ -21,6 +27,7 @@ var (
|
|||||||
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
|
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
|
||||||
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
|
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
|
||||||
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
|
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
|
||||||
|
selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode")
|
||||||
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
|
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
|
||||||
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
|
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
|
||||||
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
|
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
|
||||||
@ -30,16 +37,21 @@ var (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
log.Println("[Violet] Starting...")
|
log.Println("[Violet] Starting...")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
// create path to cert dir
|
if *certPath != "" {
|
||||||
err := os.MkdirAll(*certPath, os.ModePerm)
|
// create path to cert dir
|
||||||
if err != nil {
|
err := os.MkdirAll(*certPath, os.ModePerm)
|
||||||
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
|
if err != nil {
|
||||||
|
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// create path to key dir
|
if *keyPath != "" {
|
||||||
err = os.MkdirAll(*keyPath, os.ModePerm)
|
// create path to key dir
|
||||||
if err != nil {
|
err := os.MkdirAll(*keyPath, os.ModePerm)
|
||||||
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
|
if err != nil {
|
||||||
|
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// open sqlite database
|
// open sqlite database
|
||||||
@ -48,11 +60,12 @@ func main() {
|
|||||||
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
|
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
allowedDomains := domains.New(db) // load allowed domains
|
allowedDomains := domains.New(db) // load allowed domains
|
||||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath)) // load certificate manager
|
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
|
||||||
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
|
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
|
||||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||||
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
||||||
|
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager
|
||||||
|
|
||||||
// struct containing config for the http servers
|
// struct containing config for the http servers
|
||||||
srvConf := &servers.Conf{
|
srvConf := &servers.Conf{
|
||||||
@ -65,16 +78,41 @@ func main() {
|
|||||||
Favicons: dynamicFavicons,
|
Favicons: dynamicFavicons,
|
||||||
Verify: nil, // TODO: add mjwt verify support
|
Verify: nil, // TODO: add mjwt verify support
|
||||||
ErrorPages: dynamicErrorPages,
|
ErrorPages: dynamicErrorPages,
|
||||||
Proxy: reverseProxy,
|
Router: dynamicRouter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var srvApi, srvHttp, srvHttps *http.Server
|
||||||
if *apiListen != "" {
|
if *apiListen != "" {
|
||||||
servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains})
|
srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter})
|
||||||
}
|
}
|
||||||
if *httpListen != "" {
|
if *httpListen != "" {
|
||||||
servers.NewHttpServer(srvConf)
|
srvHttp = servers.NewHttpServer(srvConf)
|
||||||
}
|
}
|
||||||
if *httpsListen != "" {
|
if *httpsListen != "" {
|
||||||
servers.NewHttpsServer(srvConf)
|
srvHttps = servers.NewHttpsServer(srvConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait for exit signal
|
||||||
|
sc := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill)
|
||||||
|
<-sc
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// Stop servers
|
||||||
|
log.Printf("[Violet] Stopping...")
|
||||||
|
n := time.Now()
|
||||||
|
|
||||||
|
// close http servers
|
||||||
|
if srvApi != nil {
|
||||||
|
srvApi.Close()
|
||||||
|
}
|
||||||
|
if srvHttp != nil {
|
||||||
|
srvHttp.Close()
|
||||||
|
}
|
||||||
|
if srvHttps != nil {
|
||||||
|
srvHttps.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Violet] Took '%s' to shutdown\n", time.Now().Sub(n))
|
||||||
|
log.Println("[Violet] Goodbye")
|
||||||
}
|
}
|
||||||
|
6
domains/create-table-domains.sql
Normal file
6
domains/create-table-domains.sql
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS domains
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
domain TEXT,
|
||||||
|
active INTEGER DEFAULT 1
|
||||||
|
);
|
@ -2,12 +2,16 @@ package domains
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:embed create-table-domains.sql
|
||||||
|
var createTableDomains string
|
||||||
|
|
||||||
// Domains is the domain list and management system.
|
// Domains is the domain list and management system.
|
||||||
type Domains struct {
|
type Domains struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
@ -24,7 +28,7 @@ func New(db *sql.DB) *Domains {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// init domains table
|
// init domains table
|
||||||
_, err := a.db.Exec(`create table if not exists domains (id integer primary key autoincrement, domain varchar)`)
|
_, err := a.db.Exec(createTableDomains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Failed to generate 'domains' table\n")
|
log.Printf("[WARN] Failed to generate 'domains' table\n")
|
||||||
return nil
|
return nil
|
||||||
@ -37,11 +41,7 @@ func New(db *sql.DB) *Domains {
|
|||||||
|
|
||||||
// IsValid returns true if a domain is valid.
|
// IsValid returns true if a domain is valid.
|
||||||
func (d *Domains) IsValid(host string) bool {
|
func (d *Domains) IsValid(host string) bool {
|
||||||
// remove the port
|
domain, _, _ := utils.SplitDomainPort(host, 0)
|
||||||
domain, ok := utils.GetDomainWithoutPort(host)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// read lock for safety
|
// read lock for safety
|
||||||
d.s.RLock()
|
d.s.RLock()
|
||||||
@ -88,7 +88,7 @@ func (d *Domains) internalCompile(m map[string]struct{}) error {
|
|||||||
log.Println("[Domains] Updating domains from database")
|
log.Println("[Domains] Updating domains from database")
|
||||||
|
|
||||||
// sql or something?
|
// sql or something?
|
||||||
rows, err := d.db.Query("select name from domains where enabled = true")
|
rows, err := d.db.Query(`select domain from domains where active = 1`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
Binary file not shown.
10
router/create-table-redirects.sql
Normal file
10
router/create-table-redirects.sql
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS redirects
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
source TEXT,
|
||||||
|
pre INTEGER,
|
||||||
|
destination TEXT,
|
||||||
|
abs INTEGER,
|
||||||
|
code INTEGER,
|
||||||
|
active INTEGER DEFAULT 1
|
||||||
|
);
|
14
router/create-table-routes.sql
Normal file
14
router/create-table-routes.sql
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS routes
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
source TEXT,
|
||||||
|
pre INTEGER,
|
||||||
|
destination TEXT,
|
||||||
|
abs INTEGER,
|
||||||
|
cors INTEGER,
|
||||||
|
secure_mode INTEGER,
|
||||||
|
forward_host INTEGER,
|
||||||
|
forward_addr INTEGER,
|
||||||
|
ignore_cert INTEGER,
|
||||||
|
active INTEGER DEFAULT 1
|
||||||
|
);
|
228
router/manager.go
Normal file
228
router/manager.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"github.com/MrMelon54/violet/target"
|
||||||
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is a database and mutex wrap around router allowing it to be
|
||||||
|
// dynamically regenerated after updating the database of routes.
|
||||||
|
type Manager struct {
|
||||||
|
db *sql.DB
|
||||||
|
s *sync.RWMutex
|
||||||
|
r *Router
|
||||||
|
p *httputil.ReverseProxy
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed create-table-routes.sql
|
||||||
|
createTableRoutes string
|
||||||
|
//go:embed create-table-redirects.sql
|
||||||
|
createTableRedirects string
|
||||||
|
//go:embed query-table-routes.sql
|
||||||
|
queryTableRoutes string
|
||||||
|
//go:embed query-table-redirects.sql
|
||||||
|
queryTableRedirects string
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewManager create a new manager, initialises the routes and redirects tables
|
||||||
|
// in the database and runs a first time compile.
|
||||||
|
func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager {
|
||||||
|
m := &Manager{
|
||||||
|
db: db,
|
||||||
|
s: &sync.RWMutex{},
|
||||||
|
r: New(nil),
|
||||||
|
p: proxy,
|
||||||
|
}
|
||||||
|
|
||||||
|
// init routes table
|
||||||
|
_, err := m.db.Exec(createTableRoutes)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] Failed to generate 'routes' table\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// init redirects table
|
||||||
|
_, err = m.db.Exec(createTableRedirects)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] Failed to generate 'redirects' table\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// run compile to get the initial router
|
||||||
|
m.Compile()
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
m.s.RLock()
|
||||||
|
m.r.ServeHTTP(rw, req)
|
||||||
|
m.s.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Compile() {
|
||||||
|
go func() {
|
||||||
|
// new router
|
||||||
|
router := New(m.p)
|
||||||
|
|
||||||
|
// compile router and check errors
|
||||||
|
err := m.internalCompile(router)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Manager] Compile failed: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// lock while replacing router
|
||||||
|
m.s.Lock()
|
||||||
|
m.r = router
|
||||||
|
m.s.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// internalCompile is a hidden internal method for querying the database during
|
||||||
|
// the Compile() method.
|
||||||
|
func (m *Manager) internalCompile(router *Router) error {
|
||||||
|
log.Println("[Manager] Updating routes from database")
|
||||||
|
|
||||||
|
// sql or something?
|
||||||
|
rows, err := m.db.Query(queryTableRoutes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
// loop through rows and scan the options
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
|
||||||
|
src, dst string
|
||||||
|
)
|
||||||
|
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoute(router, src, dst, target.Route{
|
||||||
|
Pre: pre,
|
||||||
|
Abs: abs,
|
||||||
|
Cors: cors,
|
||||||
|
SecureMode: secure_mode,
|
||||||
|
ForwardHost: forward_host,
|
||||||
|
ForwardAddr: forward_addr,
|
||||||
|
IgnoreCert: ignore_cert,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for errors
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// sql or something?
|
||||||
|
rows, err = m.db.Query(queryTableRedirects)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
// loop through rows and scan the options
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
pre, abs bool
|
||||||
|
code int
|
||||||
|
src, dst string
|
||||||
|
)
|
||||||
|
err := rows.Scan(&src, &pre, &dst, &abs, &code)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRedirect(router, src, dst, target.Redirect{
|
||||||
|
Pre: pre,
|
||||||
|
Abs: abs,
|
||||||
|
Code: code,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for errors
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoute is an alias to parse the src and dst then add the route
|
||||||
|
func addRoute(router *Router, src string, dst string, t target.Route) error {
|
||||||
|
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// update target route values and add route
|
||||||
|
t.Host = dstHost
|
||||||
|
t.Port = dstPort
|
||||||
|
t.Path = dstPath
|
||||||
|
router.AddRoute(srcHost, srcPath, t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRedirect is an alias to parse the src and dst then add the redirect
|
||||||
|
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
|
||||||
|
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Host = dstHost
|
||||||
|
t.Port = dstPort
|
||||||
|
t.Path = dstPath
|
||||||
|
router.AddRedirect(srcHost, srcPath, t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
|
||||||
|
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
|
||||||
|
// check if source has path
|
||||||
|
var srcHost, srcPath string
|
||||||
|
nSrc := strings.IndexByte(src, '/')
|
||||||
|
if nSrc == -1 {
|
||||||
|
// set host then path to /
|
||||||
|
srcHost = src
|
||||||
|
srcPath = "/"
|
||||||
|
} else {
|
||||||
|
// set host then custom path
|
||||||
|
srcHost = src[:nSrc]
|
||||||
|
srcPath = src[nSrc:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if destination has path
|
||||||
|
var dstPath string
|
||||||
|
nDst := strings.IndexByte(dst, '/')
|
||||||
|
if nDst == -1 {
|
||||||
|
// set path to /
|
||||||
|
dstPath = "/"
|
||||||
|
} else {
|
||||||
|
// set custom path then trim dst string to the host
|
||||||
|
dstPath = dst[nDst:]
|
||||||
|
dst = dst[:nDst]
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to split the destination host into domain + port
|
||||||
|
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
|
||||||
|
if !ok {
|
||||||
|
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
|
||||||
|
}
|
7
router/query-table-redirects.sql
Normal file
7
router/query-table-redirects.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
select source,
|
||||||
|
pre,
|
||||||
|
destination,
|
||||||
|
abs,
|
||||||
|
code
|
||||||
|
from redirects
|
||||||
|
where active = true
|
11
router/query-table-routes.sql
Normal file
11
router/query-table-routes.sql
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
select source,
|
||||||
|
pre,
|
||||||
|
destination,
|
||||||
|
abs,
|
||||||
|
cors,
|
||||||
|
secure_mode,
|
||||||
|
forward_host,
|
||||||
|
forward_addr,
|
||||||
|
ignore_cert
|
||||||
|
from routes
|
||||||
|
where active = true
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/trie"
|
"github.com/MrMelon54/trie"
|
||||||
"github.com/MrMelon54/violet/target"
|
"github.com/MrMelon54/violet/target"
|
||||||
|
"github.com/MrMelon54/violet/utils"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
@ -59,7 +60,11 @@ func (r *Router) AddRedirect(host, path string, t target.Redirect) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
host := req.Host
|
if req.URL.Path == "" {
|
||||||
|
req.URL.Path = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _, _ := utils.SplitDomainPort(req.Host, 0)
|
||||||
if r.serveRedirectHTTP(rw, req, host) {
|
if r.serveRedirectHTTP(rw, req, host) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -123,8 +123,6 @@ func TestRouter_AddRedirect(t *testing.T) {
|
|||||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
|
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
|
||||||
if v == "" {
|
if v == "" {
|
||||||
u1 = nil
|
u1 = nil
|
||||||
} else if v == "/" {
|
|
||||||
u1.Path = ""
|
|
||||||
}
|
}
|
||||||
u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k}
|
u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k}
|
||||||
assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2))
|
assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2))
|
||||||
|
@ -6,19 +6,20 @@ import (
|
|||||||
"github.com/MrMelon54/violet/domains"
|
"github.com/MrMelon54/violet/domains"
|
||||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||||
"github.com/MrMelon54/violet/favicons"
|
"github.com/MrMelon54/violet/favicons"
|
||||||
|
"github.com/MrMelon54/violet/router"
|
||||||
"github.com/mrmelon54/mjwt"
|
"github.com/mrmelon54/mjwt"
|
||||||
"net/http/httputil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
|
||||||
type Conf struct {
|
type Conf struct {
|
||||||
ApiListen string
|
ApiListen string // api server listen address
|
||||||
HttpListen string
|
HttpListen string // http server listen address
|
||||||
HttpsListen string
|
HttpsListen string // https server listen address
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Domains *domains.Domains
|
Domains *domains.Domains
|
||||||
Certs *certs.Certs
|
Certs *certs.Certs
|
||||||
Favicons *favicons.Favicons
|
Favicons *favicons.Favicons
|
||||||
Verify mjwt.Provider
|
Verify mjwt.Provider
|
||||||
ErrorPages *errorPages.ErrorPages
|
ErrorPages *errorPages.ErrorPages
|
||||||
Proxy *httputil.ReverseProxy
|
Router *router.Manager
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package servers
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/violet/router"
|
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/sethvargo/go-limiter/httplimit"
|
"github.com/sethvargo/go-limiter/httplimit"
|
||||||
@ -17,20 +16,9 @@ import (
|
|||||||
// NewHttpsServer creates and runs a http server containing the public https
|
// NewHttpsServer creates and runs a http server containing the public https
|
||||||
// endpoints for the reverse proxy.
|
// endpoints for the reverse proxy.
|
||||||
func NewHttpsServer(conf *Conf) *http.Server {
|
func NewHttpsServer(conf *Conf) *http.Server {
|
||||||
r := router.New(conf.Proxy)
|
|
||||||
|
|
||||||
s := &http.Server{
|
s := &http.Server{
|
||||||
Addr: conf.HttpsListen,
|
Addr: conf.HttpsListen,
|
||||||
Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
Handler: setupRateLimiter(300).Middleware(conf.Router),
|
||||||
rw.Header().Set("Content-Type", "text/html")
|
|
||||||
rw.WriteHeader(http.StatusNotImplemented)
|
|
||||||
_, _ = rw.Write([]byte("<pre>"))
|
|
||||||
_, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req)))
|
|
||||||
_, _ = rw.Write([]byte("</pre>"))
|
|
||||||
_ = r
|
|
||||||
// TODO: serve from router and proxy
|
|
||||||
// r.ServeHTTP(rw, req)
|
|
||||||
})),
|
|
||||||
DisableGeneralOptionsHandler: false,
|
DisableGeneralOptionsHandler: false,
|
||||||
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
// error out on invalid domains
|
// error out on invalid domains
|
||||||
|
@ -6,17 +6,21 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Redirect is a target used by the router to manage redirecting the request
|
||||||
|
// using the specified configuration.
|
||||||
type Redirect struct {
|
type Redirect struct {
|
||||||
Pre bool
|
Pre bool // if the path has had a prefix removed
|
||||||
Host string
|
Host string // target host
|
||||||
Port int
|
Port int // target port
|
||||||
Path string
|
Path string // target path (possibly a prefix or absolute)
|
||||||
Abs bool
|
Abs bool // if the path is a prefix or absolute
|
||||||
Code int
|
Code int // status code used to redirect
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FullHost outputs a host:port combo or just the host if the port is 0.
|
||||||
func (r Redirect) FullHost() string {
|
func (r Redirect) FullHost() string {
|
||||||
if r.Port == 0 {
|
if r.Port == 0 {
|
||||||
return r.Host
|
return r.Host
|
||||||
@ -24,22 +28,42 @@ func (r Redirect) FullHost() string {
|
|||||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ServeHTTP responds with the redirect to the response writer provided.
|
||||||
func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// default to redirecting with StatusFound if code is not set
|
||||||
|
code := r.Code
|
||||||
|
if r.Code == 0 {
|
||||||
|
code = http.StatusFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// if not Abs then join with the ending of the current path
|
||||||
p := r.Path
|
p := r.Path
|
||||||
if !r.Abs {
|
if !r.Abs {
|
||||||
p = path.Join(r.Path, req.URL.Path)
|
p = path.Join(r.Path, req.URL.Path)
|
||||||
|
|
||||||
|
// replace the trailing slash that path.Join() strips off
|
||||||
|
if strings.HasSuffix(req.URL.Path, "/") {
|
||||||
|
p += "/"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fix empty path
|
||||||
|
if p == "" {
|
||||||
|
p = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new URL
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: req.URL.Scheme,
|
Scheme: req.URL.Scheme,
|
||||||
Host: r.FullHost(),
|
Host: r.FullHost(),
|
||||||
Path: p,
|
Path: p,
|
||||||
}
|
}
|
||||||
if u.Path == "/" {
|
|
||||||
u.Path = ""
|
// use fast redirect for speed
|
||||||
}
|
utils.FastRedirect(rw, req, u.String(), code)
|
||||||
utils.FastRedirect(rw, req, u.String(), r.Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String outputs a debug string for the redirect.
|
||||||
func (r Redirect) String() string {
|
func (r Redirect) String() string {
|
||||||
return fmt.Sprintf("%#v", r)
|
return fmt.Sprintf("%#v", r)
|
||||||
}
|
}
|
||||||
|
@ -11,8 +11,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// serveApiCors outputs the cors headers to make APIs work.
|
||||||
var serveApiCors = cors.New(cors.Options{
|
var serveApiCors = cors.New(cors.Options{
|
||||||
AllowedOrigins: []string{"*"},
|
AllowedOrigins: []string{"*"},
|
||||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||||
@ -30,28 +32,35 @@ var serveApiCors = cors.New(cors.Options{
|
|||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Route is a target used by the router to manage forwarding traffic to an
|
||||||
|
// internal server using the specified configuration.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Pre bool
|
Pre bool // if the path has had a prefix removed
|
||||||
Host string
|
Host string // target host
|
||||||
Port int
|
Port int // target port
|
||||||
Path string
|
Path string // target path (possibly a prefix or absolute)
|
||||||
Abs bool
|
Abs bool // if the path is a prefix or absolute
|
||||||
Cors bool
|
Cors bool // add CORS headers
|
||||||
SecureMode bool
|
SecureMode bool // use HTTPS internally
|
||||||
ForwardHost bool
|
ForwardHost bool // forward host header internally
|
||||||
IgnoreCert bool
|
ForwardAddr bool // forward remote address
|
||||||
Headers http.Header
|
IgnoreCert bool // ignore self-cert
|
||||||
Proxy http.Handler
|
Headers http.Header // extra headers
|
||||||
|
Proxy http.Handler // reverse proxy handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsIgnoreCert returns true if IgnoreCert is enabled.
|
||||||
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
|
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
|
||||||
|
|
||||||
|
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||||
|
// extra headers.
|
||||||
func (r Route) UpdateHeaders(header http.Header) {
|
func (r Route) UpdateHeaders(header http.Header) {
|
||||||
for k, v := range r.Headers {
|
for k, v := range r.Headers {
|
||||||
header[k] = v
|
header[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FullHost outputs a host:port combo or just the host if the port is 0.
|
||||||
func (r Route) FullHost() string {
|
func (r Route) FullHost() string {
|
||||||
if r.Port == 0 {
|
if r.Port == 0 {
|
||||||
return r.Host
|
return r.Host
|
||||||
@ -59,15 +68,21 @@ func (r Route) FullHost() string {
|
|||||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ServeHTTP responds with the data proxied from the internal server to the
|
||||||
|
// response writer provided.
|
||||||
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
if r.Cors {
|
if r.Cors {
|
||||||
|
// wraps with CORS handler
|
||||||
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
|
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
|
||||||
} else {
|
} else {
|
||||||
r.internalServeHTTP(rw, req)
|
r.internalServeHTTP(rw, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// internalServeHTTP is an internal method which handles configuring the request
|
||||||
|
// for the reverse proxy handler.
|
||||||
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// set the scheme and port using defaults if the port is 0
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
if r.SecureMode {
|
if r.SecureMode {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
@ -80,40 +95,76 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if not Abs then join with the ending of the current path
|
||||||
p := r.Path
|
p := r.Path
|
||||||
if !r.Abs {
|
if !r.Abs {
|
||||||
p = path.Join(r.Path, req.URL.Path)
|
p = path.Join(r.Path, req.URL.Path)
|
||||||
|
|
||||||
|
// replace the trailing slash that path.Join() strips off
|
||||||
|
if strings.HasSuffix(req.URL.Path, "/") {
|
||||||
|
p += "/"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fix empty path
|
||||||
if p == "" {
|
if p == "" {
|
||||||
p = "/"
|
p = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: don't just copy the body into a buffer as this is really slow
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
if req.Body != nil {
|
if req.Body != nil {
|
||||||
_, _ = io.Copy(buf, req.Body)
|
_, _ = io.Copy(buf, req.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a new URL
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: scheme,
|
Scheme: scheme,
|
||||||
Host: r.FullHost(),
|
Host: r.FullHost(),
|
||||||
Path: p,
|
Path: p,
|
||||||
RawQuery: req.URL.RawQuery,
|
RawQuery: req.URL.RawQuery,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create the internal request
|
||||||
req2, err := http.NewRequest(req.Method, u.String(), buf)
|
req2, err := http.NewRequest(req.Method, u.String(), buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
|
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
|
||||||
utils.RespondHttpStatus(rw, http.StatusBadGateway)
|
utils.RespondHttpStatus(rw, http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loops over the incoming request headers
|
||||||
for k, v := range req.Header {
|
for k, v := range req.Header {
|
||||||
|
// ignore host header
|
||||||
if k == "Host" {
|
if k == "Host" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// copy header into the internal request
|
||||||
req2.Header[k] = v
|
req2.Header[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if extra route headers are set
|
||||||
|
if r.Headers != nil {
|
||||||
|
// loop over headers
|
||||||
|
for k, v := range r.Headers {
|
||||||
|
// copy header into the internal request
|
||||||
|
req2.Header[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if forward host is enabled then send the host
|
||||||
if r.ForwardHost {
|
if r.ForwardHost {
|
||||||
req2.Host = req.Host
|
req2.Host = req.Host
|
||||||
}
|
}
|
||||||
|
if r.ForwardAddr {
|
||||||
|
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serve request with reverse proxy
|
||||||
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
|
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String outputs a debug string for the route.
|
||||||
|
func (r Route) String() string {
|
||||||
|
return fmt.Sprintf("%#v", r)
|
||||||
|
}
|
||||||
|
75
target/route_test.go
Normal file
75
target/route_test.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package target
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyTester struct {
|
||||||
|
got bool
|
||||||
|
rw http.ResponseWriter
|
||||||
|
req *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
p.got = true
|
||||||
|
p.rw = rw
|
||||||
|
p.req = req
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoute_FullHost(t *testing.T) {
|
||||||
|
assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost())
|
||||||
|
assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoute_ServeHTTP(t *testing.T) {
|
||||||
|
a := []struct {
|
||||||
|
Route
|
||||||
|
target string
|
||||||
|
}{
|
||||||
|
{Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"},
|
||||||
|
{Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"},
|
||||||
|
{Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"},
|
||||||
|
{Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"},
|
||||||
|
{Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"},
|
||||||
|
{Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"},
|
||||||
|
}
|
||||||
|
for _, i := range a {
|
||||||
|
pt := &proxyTester{}
|
||||||
|
i.Proxy = pt
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
|
||||||
|
i.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
assert.True(t, pt.got)
|
||||||
|
assert.Equal(t, i.target, pt.req.URL.String())
|
||||||
|
if i.ForwardAddr {
|
||||||
|
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
|
||||||
|
}
|
||||||
|
if i.ForwardHost {
|
||||||
|
assert.Equal(t, req.Host, pt.req.Host)
|
||||||
|
}
|
||||||
|
if i.Headers != nil {
|
||||||
|
assert.Equal(t, i.Headers, pt.req.Header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoute_ServeHTTP_Cors(t *testing.T) {
|
||||||
|
pt := &proxyTester{}
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
||||||
|
req.Header.Set("Origin", "https://test.example.com")
|
||||||
|
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt}
|
||||||
|
i.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
assert.True(t, pt.got)
|
||||||
|
assert.Equal(t, http.MethodOptions, pt.req.Method)
|
||||||
|
assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String())
|
||||||
|
assert.Equal(t, "Origin", res.Header().Get("Vary"))
|
||||||
|
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials"))
|
||||||
|
assert.Equal(t, "Origin", res.Header().Get("Vary"))
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RespondHttpStatus outputs the status code and text using http.Error()
|
||||||
func RespondHttpStatus(rw http.ResponseWriter, status int) {
|
func RespondHttpStatus(rw http.ResponseWriter, status int) {
|
||||||
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
|
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// logHttpServerError is the internal function powering the logging in
|
||||||
|
// RunBackgroundHttp and RunBackgroundHttps.
|
||||||
func logHttpServerError(prefix string, err error) {
|
func logHttpServerError(prefix string, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == http.ErrServerClosed {
|
if err == http.ErrServerClosed {
|
||||||
@ -16,10 +18,14 @@ func logHttpServerError(prefix string, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunBackgroundHttp runs a http server and logs when the server closes or
|
||||||
|
// errors.
|
||||||
func RunBackgroundHttp(prefix string, s *http.Server) {
|
func RunBackgroundHttp(prefix string, s *http.Server) {
|
||||||
logHttpServerError(prefix, s.ListenAndServe())
|
logHttpServerError(prefix, s.ListenAndServe())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunBackgroundHttps runs a http server with TLS encryption and logs when the
|
||||||
|
// server closes or errors.
|
||||||
func RunBackgroundHttps(prefix string, s *http.Server) {
|
func RunBackgroundHttps(prefix string, s *http.Server) {
|
||||||
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
|
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user