mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 10:51:40 +00:00
Add Ctrl+C handling, self-signed mode for devs and fix some bugs in routing
This commit is contained in:
parent
0551e15979
commit
9147a813cb
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*.sqlite
|
@ -1,11 +1,11 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
||||
<data-source source="LOCAL" name="identifier.sqlite" uuid="a1c751d4-a71e-4c87-b033-ea49e424ae9a">
|
||||
<data-source source="LOCAL" name="__db.sqlite" uuid="5aeb4e88-8ec4-4227-a921-ba4eaed357bf">
|
||||
<driver-ref>sqlite.xerial</driver-ref>
|
||||
<synchronize>true</synchronize>
|
||||
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
||||
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
|
||||
<jdbc-url>jdbc:sqlite:__db.sqlite</jdbc-url>
|
||||
<working-dir>$ProjectFileDir$</working-dir>
|
||||
</data-source>
|
||||
</component>
|
||||
|
@ -3,34 +3,55 @@ package certs
|
||||
import (
|
||||
"code.mrmelon54.com/melon/certgen"
|
||||
"crypto/tls"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"io/fs"
|
||||
"log"
|
||||
"math/big"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Certs is the certificate loader and management system.
|
||||
type Certs struct {
|
||||
cDir fs.FS
|
||||
kDir fs.FS
|
||||
ss bool
|
||||
s *sync.RWMutex
|
||||
m map[string]*tls.Certificate
|
||||
ca *certgen.CertGen
|
||||
sn atomic.Int64
|
||||
}
|
||||
|
||||
// New creates a new cert list
|
||||
func New(certDir fs.FS, keyDir fs.FS) *Certs {
|
||||
a := &Certs{
|
||||
func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
|
||||
c := &Certs{
|
||||
cDir: certDir,
|
||||
kDir: keyDir,
|
||||
ss: selfCert,
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[string]*tls.Certificate),
|
||||
}
|
||||
if c.ss {
|
||||
ca, err := certgen.MakeCaTls(pkix.Name{
|
||||
Country: []string{"GB"},
|
||||
Organization: []string{"Violet"},
|
||||
OrganizationalUnit: []string{"Development"},
|
||||
SerialNumber: "0",
|
||||
CommonName: fmt.Sprintf("%d.violet.test", time.Now().Unix()),
|
||||
}, big.NewInt(0))
|
||||
if err != nil {
|
||||
log.Fatalln("Failed to generate CA cert for self-signed mode:", err)
|
||||
}
|
||||
c.ca = ca
|
||||
}
|
||||
|
||||
// run compile to get the initial data
|
||||
a.Compile()
|
||||
return a
|
||||
c.Compile()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
||||
@ -43,6 +64,24 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
||||
return cert
|
||||
}
|
||||
|
||||
// if self-signed certificate is enabled then generate a certificate
|
||||
if c.ss {
|
||||
sn := c.sn.Add(1)
|
||||
serverTls, err := certgen.MakeServerTls(c.ca, pkix.Name{
|
||||
Country: []string{"GB"},
|
||||
Organization: []string{domain},
|
||||
OrganizationalUnit: []string{domain},
|
||||
SerialNumber: fmt.Sprintf("%d", sn),
|
||||
CommonName: domain,
|
||||
}, big.NewInt(sn), []string{domain}, nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
leaf := serverTls.GetTlsLeaf()
|
||||
c.m[domain] = &leaf
|
||||
return &leaf
|
||||
}
|
||||
|
||||
// lookup and return wildcard cert
|
||||
if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok {
|
||||
if cert, ok := c.m[wildcardDomain]; ok {
|
||||
@ -55,6 +94,11 @@ func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
||||
}
|
||||
|
||||
func (c *Certs) Compile() {
|
||||
// don't bother compiling in self-signed mode
|
||||
if c.ss {
|
||||
return
|
||||
}
|
||||
|
||||
// async compile magic
|
||||
go func() {
|
||||
// new map
|
||||
|
@ -4,16 +4,22 @@ import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/certs"
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/servers"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// flags - each one has a usage field lol
|
||||
@ -21,6 +27,7 @@ var (
|
||||
databasePath = flag.String("db", "", "/path/to/database.sqlite : path to the database file")
|
||||
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
|
||||
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
|
||||
selfSigned = flag.Bool("ss", false, "enable self-signed certificate mode")
|
||||
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
|
||||
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
|
||||
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
|
||||
@ -30,16 +37,21 @@ var (
|
||||
|
||||
func main() {
|
||||
log.Println("[Violet] Starting...")
|
||||
flag.Parse()
|
||||
|
||||
// create path to cert dir
|
||||
err := os.MkdirAll(*certPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
|
||||
if *certPath != "" {
|
||||
// create path to cert dir
|
||||
err := os.MkdirAll(*certPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
|
||||
}
|
||||
}
|
||||
// create path to key dir
|
||||
err = os.MkdirAll(*keyPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
|
||||
if *keyPath != "" {
|
||||
// create path to key dir
|
||||
err := os.MkdirAll(*keyPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
|
||||
}
|
||||
}
|
||||
|
||||
// open sqlite database
|
||||
@ -48,11 +60,12 @@ func main() {
|
||||
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
|
||||
}
|
||||
|
||||
allowedDomains := domains.New(db) // load allowed domains
|
||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath)) // load certificate manager
|
||||
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
|
||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
||||
allowedDomains := domains.New(db) // load allowed domains
|
||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
|
||||
reverseProxy := proxy.CreateHybridReverseProxy() // load reverse proxy
|
||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||
dynamicErrorPages := errorPages.New(os.DirFS(*errorPagePath)) // load dynamic error page provider
|
||||
dynamicRouter := router.NewManager(db, reverseProxy) // load dynamic router manager
|
||||
|
||||
// struct containing config for the http servers
|
||||
srvConf := &servers.Conf{
|
||||
@ -65,16 +78,41 @@ func main() {
|
||||
Favicons: dynamicFavicons,
|
||||
Verify: nil, // TODO: add mjwt verify support
|
||||
ErrorPages: dynamicErrorPages,
|
||||
Proxy: reverseProxy,
|
||||
Router: dynamicRouter,
|
||||
}
|
||||
|
||||
var srvApi, srvHttp, srvHttps *http.Server
|
||||
if *apiListen != "" {
|
||||
servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains})
|
||||
srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter})
|
||||
}
|
||||
if *httpListen != "" {
|
||||
servers.NewHttpServer(srvConf)
|
||||
srvHttp = servers.NewHttpServer(srvConf)
|
||||
}
|
||||
if *httpsListen != "" {
|
||||
servers.NewHttpsServer(srvConf)
|
||||
srvHttps = servers.NewHttpsServer(srvConf)
|
||||
}
|
||||
|
||||
// Wait for exit signal
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill)
|
||||
<-sc
|
||||
fmt.Println()
|
||||
|
||||
// Stop servers
|
||||
log.Printf("[Violet] Stopping...")
|
||||
n := time.Now()
|
||||
|
||||
// close http servers
|
||||
if srvApi != nil {
|
||||
srvApi.Close()
|
||||
}
|
||||
if srvHttp != nil {
|
||||
srvHttp.Close()
|
||||
}
|
||||
if srvHttps != nil {
|
||||
srvHttps.Close()
|
||||
}
|
||||
|
||||
log.Printf("[Violet] Took '%s' to shutdown\n", time.Now().Sub(n))
|
||||
log.Println("[Violet] Goodbye")
|
||||
}
|
||||
|
6
domains/create-table-domains.sql
Normal file
6
domains/create-table-domains.sql
Normal file
@ -0,0 +1,6 @@
|
||||
CREATE TABLE IF NOT EXISTS domains
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain TEXT,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
@ -2,12 +2,16 @@ package domains
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//go:embed create-table-domains.sql
|
||||
var createTableDomains string
|
||||
|
||||
// Domains is the domain list and management system.
|
||||
type Domains struct {
|
||||
db *sql.DB
|
||||
@ -24,7 +28,7 @@ func New(db *sql.DB) *Domains {
|
||||
}
|
||||
|
||||
// init domains table
|
||||
_, err := a.db.Exec(`create table if not exists domains (id integer primary key autoincrement, domain varchar)`)
|
||||
_, err := a.db.Exec(createTableDomains)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'domains' table\n")
|
||||
return nil
|
||||
@ -37,11 +41,7 @@ func New(db *sql.DB) *Domains {
|
||||
|
||||
// IsValid returns true if a domain is valid.
|
||||
func (d *Domains) IsValid(host string) bool {
|
||||
// remove the port
|
||||
domain, ok := utils.GetDomainWithoutPort(host)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
domain, _, _ := utils.SplitDomainPort(host, 0)
|
||||
|
||||
// read lock for safety
|
||||
d.s.RLock()
|
||||
@ -88,7 +88,7 @@ func (d *Domains) internalCompile(m map[string]struct{}) error {
|
||||
log.Println("[Domains] Updating domains from database")
|
||||
|
||||
// sql or something?
|
||||
rows, err := d.db.Query("select name from domains where enabled = true")
|
||||
rows, err := d.db.Query(`select domain from domains where active = 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Binary file not shown.
10
router/create-table-redirects.sql
Normal file
10
router/create-table-redirects.sql
Normal file
@ -0,0 +1,10 @@
|
||||
CREATE TABLE IF NOT EXISTS redirects
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT,
|
||||
pre INTEGER,
|
||||
destination TEXT,
|
||||
abs INTEGER,
|
||||
code INTEGER,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
14
router/create-table-routes.sql
Normal file
14
router/create-table-routes.sql
Normal file
@ -0,0 +1,14 @@
|
||||
CREATE TABLE IF NOT EXISTS routes
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT,
|
||||
pre INTEGER,
|
||||
destination TEXT,
|
||||
abs INTEGER,
|
||||
cors INTEGER,
|
||||
secure_mode INTEGER,
|
||||
forward_host INTEGER,
|
||||
forward_addr INTEGER,
|
||||
ignore_cert INTEGER,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
228
router/manager.go
Normal file
228
router/manager.go
Normal file
@ -0,0 +1,228 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager is a database and mutex wrap around router allowing it to be
|
||||
// dynamically regenerated after updating the database of routes.
|
||||
type Manager struct {
|
||||
db *sql.DB
|
||||
s *sync.RWMutex
|
||||
r *Router
|
||||
p *httputil.ReverseProxy
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed create-table-routes.sql
|
||||
createTableRoutes string
|
||||
//go:embed create-table-redirects.sql
|
||||
createTableRedirects string
|
||||
//go:embed query-table-routes.sql
|
||||
queryTableRoutes string
|
||||
//go:embed query-table-redirects.sql
|
||||
queryTableRedirects string
|
||||
)
|
||||
|
||||
// NewManager create a new manager, initialises the routes and redirects tables
|
||||
// in the database and runs a first time compile.
|
||||
func NewManager(db *sql.DB, proxy *httputil.ReverseProxy) *Manager {
|
||||
m := &Manager{
|
||||
db: db,
|
||||
s: &sync.RWMutex{},
|
||||
r: New(nil),
|
||||
p: proxy,
|
||||
}
|
||||
|
||||
// init routes table
|
||||
_, err := m.db.Exec(createTableRoutes)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'routes' table\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// init redirects table
|
||||
_, err = m.db.Exec(createTableRedirects)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'redirects' table\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// run compile to get the initial router
|
||||
m.Compile()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
m.s.RLock()
|
||||
m.r.ServeHTTP(rw, req)
|
||||
m.s.RUnlock()
|
||||
}
|
||||
|
||||
func (m *Manager) Compile() {
|
||||
go func() {
|
||||
// new router
|
||||
router := New(m.p)
|
||||
|
||||
// compile router and check errors
|
||||
err := m.internalCompile(router)
|
||||
if err != nil {
|
||||
log.Printf("[Manager] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// lock while replacing router
|
||||
m.s.Lock()
|
||||
m.r = router
|
||||
m.s.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// internalCompile is a hidden internal method for querying the database during
|
||||
// the Compile() method.
|
||||
func (m *Manager) internalCompile(router *Router) error {
|
||||
log.Println("[Manager] Updating routes from database")
|
||||
|
||||
// sql or something?
|
||||
rows, err := m.db.Query(queryTableRoutes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// loop through rows and scan the options
|
||||
for rows.Next() {
|
||||
var (
|
||||
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
|
||||
src, dst string
|
||||
)
|
||||
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = addRoute(router, src, dst, target.Route{
|
||||
Pre: pre,
|
||||
Abs: abs,
|
||||
Cors: cors,
|
||||
SecureMode: secure_mode,
|
||||
ForwardHost: forward_host,
|
||||
ForwardAddr: forward_addr,
|
||||
IgnoreCert: ignore_cert,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for errors
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sql or something?
|
||||
rows, err = m.db.Query(queryTableRedirects)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// loop through rows and scan the options
|
||||
for rows.Next() {
|
||||
var (
|
||||
pre, abs bool
|
||||
code int
|
||||
src, dst string
|
||||
)
|
||||
err := rows.Scan(&src, &pre, &dst, &abs, &code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = addRedirect(router, src, dst, target.Redirect{
|
||||
Pre: pre,
|
||||
Abs: abs,
|
||||
Code: code,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for errors
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// addRoute is an alias to parse the src and dst then add the route
|
||||
func addRoute(router *Router, src string, dst string, t target.Route) error {
|
||||
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update target route values and add route
|
||||
t.Host = dstHost
|
||||
t.Port = dstPort
|
||||
t.Path = dstPath
|
||||
router.AddRoute(srcHost, srcPath, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRedirect is an alias to parse the src and dst then add the redirect
|
||||
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
|
||||
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Host = dstHost
|
||||
t.Port = dstPort
|
||||
t.Path = dstPath
|
||||
router.AddRedirect(srcHost, srcPath, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
|
||||
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
|
||||
// check if source has path
|
||||
var srcHost, srcPath string
|
||||
nSrc := strings.IndexByte(src, '/')
|
||||
if nSrc == -1 {
|
||||
// set host then path to /
|
||||
srcHost = src
|
||||
srcPath = "/"
|
||||
} else {
|
||||
// set host then custom path
|
||||
srcHost = src[:nSrc]
|
||||
srcPath = src[nSrc:]
|
||||
}
|
||||
|
||||
// check if destination has path
|
||||
var dstPath string
|
||||
nDst := strings.IndexByte(dst, '/')
|
||||
if nDst == -1 {
|
||||
// set path to /
|
||||
dstPath = "/"
|
||||
} else {
|
||||
// set custom path then trim dst string to the host
|
||||
dstPath = dst[nDst:]
|
||||
dst = dst[:nDst]
|
||||
}
|
||||
|
||||
// try to split the destination host into domain + port
|
||||
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
|
||||
if !ok {
|
||||
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
|
||||
}
|
||||
|
||||
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
|
||||
}
|
7
router/query-table-redirects.sql
Normal file
7
router/query-table-redirects.sql
Normal file
@ -0,0 +1,7 @@
|
||||
select source,
|
||||
pre,
|
||||
destination,
|
||||
abs,
|
||||
code
|
||||
from redirects
|
||||
where active = true
|
11
router/query-table-routes.sql
Normal file
11
router/query-table-routes.sql
Normal file
@ -0,0 +1,11 @@
|
||||
select source,
|
||||
pre,
|
||||
destination,
|
||||
abs,
|
||||
cors,
|
||||
secure_mode,
|
||||
forward_host,
|
||||
forward_addr,
|
||||
ignore_cert
|
||||
from routes
|
||||
where active = true
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/MrMelon54/trie"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
@ -59,7 +60,11 @@ func (r *Router) AddRedirect(host, path string, t target.Redirect) {
|
||||
}
|
||||
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
host := req.Host
|
||||
if req.URL.Path == "" {
|
||||
req.URL.Path = "/"
|
||||
}
|
||||
|
||||
host, _, _ := utils.SplitDomainPort(req.Host, 0)
|
||||
if r.serveRedirectHTTP(rw, req, host) {
|
||||
return
|
||||
}
|
||||
|
@ -123,8 +123,6 @@ func TestRouter_AddRedirect(t *testing.T) {
|
||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
|
||||
if v == "" {
|
||||
u1 = nil
|
||||
} else if v == "/" {
|
||||
u1.Path = ""
|
||||
}
|
||||
u2 := &url.URL{Scheme: "https", Host: "www.example.com", Path: k}
|
||||
assertHttpRedirect(t, r, http.StatusFound, outputUrl(u1), http.MethodGet, outputUrl(u2))
|
||||
|
@ -6,19 +6,20 @@ import (
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/mrmelon54/mjwt"
|
||||
"net/http/httputil"
|
||||
)
|
||||
|
||||
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
|
||||
type Conf struct {
|
||||
ApiListen string
|
||||
HttpListen string
|
||||
HttpsListen string
|
||||
ApiListen string // api server listen address
|
||||
HttpListen string // http server listen address
|
||||
HttpsListen string // https server listen address
|
||||
DB *sql.DB
|
||||
Domains *domains.Domains
|
||||
Certs *certs.Certs
|
||||
Favicons *favicons.Favicons
|
||||
Verify mjwt.Provider
|
||||
ErrorPages *errorPages.ErrorPages
|
||||
Proxy *httputil.ReverseProxy
|
||||
Router *router.Manager
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package servers
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sethvargo/go-limiter/httplimit"
|
||||
@ -17,20 +16,9 @@ import (
|
||||
// NewHttpsServer creates and runs a http server containing the public https
|
||||
// endpoints for the reverse proxy.
|
||||
func NewHttpsServer(conf *Conf) *http.Server {
|
||||
r := router.New(conf.Proxy)
|
||||
|
||||
s := &http.Server{
|
||||
Addr: conf.HttpsListen,
|
||||
Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
rw.WriteHeader(http.StatusNotImplemented)
|
||||
_, _ = rw.Write([]byte("<pre>"))
|
||||
_, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req)))
|
||||
_, _ = rw.Write([]byte("</pre>"))
|
||||
_ = r
|
||||
// TODO: serve from router and proxy
|
||||
// r.ServeHTTP(rw, req)
|
||||
})),
|
||||
Addr: conf.HttpsListen,
|
||||
Handler: setupRateLimiter(300).Middleware(conf.Router),
|
||||
DisableGeneralOptionsHandler: false,
|
||||
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// error out on invalid domains
|
||||
|
@ -6,17 +6,21 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Redirect is a target used by the router to manage redirecting the request
|
||||
// using the specified configuration.
|
||||
type Redirect struct {
|
||||
Pre bool
|
||||
Host string
|
||||
Port int
|
||||
Path string
|
||||
Abs bool
|
||||
Code int
|
||||
Pre bool // if the path has had a prefix removed
|
||||
Host string // target host
|
||||
Port int // target port
|
||||
Path string // target path (possibly a prefix or absolute)
|
||||
Abs bool // if the path is a prefix or absolute
|
||||
Code int // status code used to redirect
|
||||
}
|
||||
|
||||
// FullHost outputs a host:port combo or just the host if the port is 0.
|
||||
func (r Redirect) FullHost() string {
|
||||
if r.Port == 0 {
|
||||
return r.Host
|
||||
@ -24,22 +28,42 @@ func (r Redirect) FullHost() string {
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
// ServeHTTP responds with the redirect to the response writer provided.
|
||||
func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// default to redirecting with StatusFound if code is not set
|
||||
code := r.Code
|
||||
if r.Code == 0 {
|
||||
code = http.StatusFound
|
||||
}
|
||||
|
||||
// if not Abs then join with the ending of the current path
|
||||
p := r.Path
|
||||
if !r.Abs {
|
||||
p = path.Join(r.Path, req.URL.Path)
|
||||
|
||||
// replace the trailing slash that path.Join() strips off
|
||||
if strings.HasSuffix(req.URL.Path, "/") {
|
||||
p += "/"
|
||||
}
|
||||
}
|
||||
|
||||
// fix empty path
|
||||
if p == "" {
|
||||
p = "/"
|
||||
}
|
||||
|
||||
// create a new URL
|
||||
u := &url.URL{
|
||||
Scheme: req.URL.Scheme,
|
||||
Host: r.FullHost(),
|
||||
Path: p,
|
||||
}
|
||||
if u.Path == "/" {
|
||||
u.Path = ""
|
||||
}
|
||||
utils.FastRedirect(rw, req, u.String(), r.Code)
|
||||
|
||||
// use fast redirect for speed
|
||||
utils.FastRedirect(rw, req, u.String(), code)
|
||||
}
|
||||
|
||||
// String outputs a debug string for the redirect.
|
||||
func (r Redirect) String() string {
|
||||
return fmt.Sprintf("%#v", r)
|
||||
}
|
||||
|
@ -11,8 +11,10 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// serveApiCors outputs the cors headers to make APIs work.
|
||||
var serveApiCors = cors.New(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||
@ -30,28 +32,35 @@ var serveApiCors = cors.New(cors.Options{
|
||||
AllowCredentials: true,
|
||||
})
|
||||
|
||||
// Route is a target used by the router to manage forwarding traffic to an
|
||||
// internal server using the specified configuration.
|
||||
type Route struct {
|
||||
Pre bool
|
||||
Host string
|
||||
Port int
|
||||
Path string
|
||||
Abs bool
|
||||
Cors bool
|
||||
SecureMode bool
|
||||
ForwardHost bool
|
||||
IgnoreCert bool
|
||||
Headers http.Header
|
||||
Proxy http.Handler
|
||||
Pre bool // if the path has had a prefix removed
|
||||
Host string // target host
|
||||
Port int // target port
|
||||
Path string // target path (possibly a prefix or absolute)
|
||||
Abs bool // if the path is a prefix or absolute
|
||||
Cors bool // add CORS headers
|
||||
SecureMode bool // use HTTPS internally
|
||||
ForwardHost bool // forward host header internally
|
||||
ForwardAddr bool // forward remote address
|
||||
IgnoreCert bool // ignore self-cert
|
||||
Headers http.Header // extra headers
|
||||
Proxy http.Handler // reverse proxy handler
|
||||
}
|
||||
|
||||
// IsIgnoreCert returns true if IgnoreCert is enabled.
|
||||
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
|
||||
|
||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||
// extra headers.
|
||||
func (r Route) UpdateHeaders(header http.Header) {
|
||||
for k, v := range r.Headers {
|
||||
header[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// FullHost outputs a host:port combo or just the host if the port is 0.
|
||||
func (r Route) FullHost() string {
|
||||
if r.Port == 0 {
|
||||
return r.Host
|
||||
@ -59,15 +68,21 @@ func (r Route) FullHost() string {
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
// ServeHTTP responds with the data proxied from the internal server to the
|
||||
// response writer provided.
|
||||
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if r.Cors {
|
||||
// wraps with CORS handler
|
||||
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
|
||||
} else {
|
||||
r.internalServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// internalServeHTTP is an internal method which handles configuring the request
|
||||
// for the reverse proxy handler.
|
||||
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// set the scheme and port using defaults if the port is 0
|
||||
scheme := "http"
|
||||
if r.SecureMode {
|
||||
scheme = "https"
|
||||
@ -80,40 +95,76 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// if not Abs then join with the ending of the current path
|
||||
p := r.Path
|
||||
if !r.Abs {
|
||||
p = path.Join(r.Path, req.URL.Path)
|
||||
|
||||
// replace the trailing slash that path.Join() strips off
|
||||
if strings.HasSuffix(req.URL.Path, "/") {
|
||||
p += "/"
|
||||
}
|
||||
}
|
||||
|
||||
// fix empty path
|
||||
if p == "" {
|
||||
p = "/"
|
||||
}
|
||||
|
||||
// TODO: don't just copy the body into a buffer as this is really slow
|
||||
buf := new(bytes.Buffer)
|
||||
if req.Body != nil {
|
||||
_, _ = io.Copy(buf, req.Body)
|
||||
}
|
||||
|
||||
// create a new URL
|
||||
u := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: r.FullHost(),
|
||||
Path: p,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
}
|
||||
|
||||
// create the internal request
|
||||
req2, err := http.NewRequest(req.Method, u.String(), buf)
|
||||
if err != nil {
|
||||
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
|
||||
utils.RespondHttpStatus(rw, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// loops over the incoming request headers
|
||||
for k, v := range req.Header {
|
||||
// ignore host header
|
||||
if k == "Host" {
|
||||
continue
|
||||
}
|
||||
// copy header into the internal request
|
||||
req2.Header[k] = v
|
||||
}
|
||||
|
||||
// if extra route headers are set
|
||||
if r.Headers != nil {
|
||||
// loop over headers
|
||||
for k, v := range r.Headers {
|
||||
// copy header into the internal request
|
||||
req2.Header[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// if forward host is enabled then send the host
|
||||
if r.ForwardHost {
|
||||
req2.Host = req.Host
|
||||
}
|
||||
if r.ForwardAddr {
|
||||
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||
}
|
||||
|
||||
// serve request with reverse proxy
|
||||
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
|
||||
}
|
||||
|
||||
// String outputs a debug string for the route.
|
||||
func (r Route) String() string {
|
||||
return fmt.Sprintf("%#v", r)
|
||||
}
|
||||
|
75
target/route_test.go
Normal file
75
target/route_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package target
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type proxyTester struct {
|
||||
got bool
|
||||
rw http.ResponseWriter
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
func (p *proxyTester) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
p.got = true
|
||||
p.rw = rw
|
||||
p.req = req
|
||||
}
|
||||
|
||||
func TestRoute_FullHost(t *testing.T) {
|
||||
assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost())
|
||||
assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost())
|
||||
}
|
||||
|
||||
func TestRoute_ServeHTTP(t *testing.T) {
|
||||
a := []struct {
|
||||
Route
|
||||
target string
|
||||
}{
|
||||
{Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"},
|
||||
{Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"},
|
||||
{Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"},
|
||||
{Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"},
|
||||
{Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"},
|
||||
{Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"},
|
||||
}
|
||||
for _, i := range a {
|
||||
pt := &proxyTester{}
|
||||
i.Proxy = pt
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
|
||||
i.ServeHTTP(res, req)
|
||||
|
||||
assert.True(t, pt.got)
|
||||
assert.Equal(t, i.target, pt.req.URL.String())
|
||||
if i.ForwardAddr {
|
||||
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
|
||||
}
|
||||
if i.ForwardHost {
|
||||
assert.Equal(t, req.Host, pt.req.Host)
|
||||
}
|
||||
if i.Headers != nil {
|
||||
assert.Equal(t, i.Headers, pt.req.Header)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoute_ServeHTTP_Cors(t *testing.T) {
|
||||
pt := &proxyTester{}
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
||||
req.Header.Set("Origin", "https://test.example.com")
|
||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt}
|
||||
i.ServeHTTP(res, req)
|
||||
|
||||
assert.True(t, pt.got)
|
||||
assert.Equal(t, http.MethodOptions, pt.req.Method)
|
||||
assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String())
|
||||
assert.Equal(t, "Origin", res.Header().Get("Vary"))
|
||||
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials"))
|
||||
assert.Equal(t, "Origin", res.Header().Get("Vary"))
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// RespondHttpStatus outputs the status code and text using http.Error()
|
||||
func RespondHttpStatus(rw http.ResponseWriter, status int) {
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
|
||||
}
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// logHttpServerError is the internal function powering the logging in
|
||||
// RunBackgroundHttp and RunBackgroundHttps.
|
||||
func logHttpServerError(prefix string, err error) {
|
||||
if err != nil {
|
||||
if err == http.ErrServerClosed {
|
||||
@ -16,10 +18,14 @@ func logHttpServerError(prefix string, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// RunBackgroundHttp runs a http server and logs when the server closes or
|
||||
// errors.
|
||||
func RunBackgroundHttp(prefix string, s *http.Server) {
|
||||
logHttpServerError(prefix, s.ListenAndServe())
|
||||
}
|
||||
|
||||
// RunBackgroundHttps runs a http server with TLS encryption and logs when the
|
||||
// server closes or errors.
|
||||
func RunBackgroundHttps(prefix string, s *http.Server) {
|
||||
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user