violet/router/manager.go
2023-06-03 19:33:06 +01:00

229 lines
5.1 KiB
Go

package router
import (
"database/sql"
_ "embed"
"fmt"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"log"
"net/http"
"strings"
"sync"
)
// Manager is a database and mutex wrap around router allowing it to be
// dynamically regenerated after updating the database of routes.
//
// Requests are served by delegating to the current router (see ServeHTTP),
// and that router is replaced each time Compile succeeds.
type Manager struct {
// db is the database in which the routes and redirects tables are created
// and from which they are queried.
db *sql.DB
// s guards r: it is read-locked while serving a request and write-locked
// while Compile swaps in a new router.
s *sync.RWMutex
// r is the router currently used to serve requests.
r *Router
// p is the transport handed to New when Compile builds a router.
p *proxy.HybridTransport
}
// SQL statements embedded at build time from the .sql files that sit next to
// this source file.
var (
// createTableRoutes is executed by NewManager to initialise the routes table.
//go:embed create-table-routes.sql
createTableRoutes string
// createTableRedirects is executed by NewManager to initialise the redirects
// table.
//go:embed create-table-redirects.sql
createTableRedirects string
// queryTableRoutes is run by internalCompile; each row is scanned as
// (src, pre, dst, abs, cors, secure_mode, forward_host, forward_addr,
// ignore_cert).
//go:embed query-table-routes.sql
queryTableRoutes string
// queryTableRedirects is run by internalCompile; each row is scanned as
// (src, pre, dst, abs, code).
//go:embed query-table-redirects.sql
queryTableRedirects string
)
// NewManager creates a new manager, initialises the routes and redirects
// tables in the database and starts a first compile of the router.
//
// The first compile is started through Compile, which does its work in a
// separate goroutine; until it finishes the manager serves requests with the
// router created by New(nil).
//
// If either table cannot be created the failure is logged together with the
// underlying database error and nil is returned, so callers must check the
// result for nil before using it.
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
	m := &Manager{
		db: db,
		s:  &sync.RWMutex{},
		r:  New(nil),
		p:  proxy,
	}

	// init routes table
	if _, err := m.db.Exec(createTableRoutes); err != nil {
		// include the cause, otherwise the log gives no hint why it failed
		log.Printf("[WARN] Failed to generate 'routes' table: %s\n", err)
		return nil
	}

	// init redirects table
	if _, err := m.db.Exec(createTableRedirects); err != nil {
		log.Printf("[WARN] Failed to generate 'redirects' table: %s\n", err)
		return nil
	}

	// run compile to get the initial router
	m.Compile()
	return m
}
// ServeHTTP implements http.Handler by passing the request to the router that
// is current at the time the request arrives.
//
// The read lock is held only long enough to copy the router pointer and is
// released before the request is served. Holding it for the whole request
// would leak the lock if the router panicked, and would let one long-running
// request block the write lock taken by Compile, which in turn blocks every
// new reader. Serving from the copied pointer is safe because Compile installs
// a freshly built router instead of modifying the one already in use.
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	m.s.RLock()
	r := m.r
	m.s.RUnlock()

	r.ServeHTTP(rw, req)
}
// Compile rebuilds the router from the database in a background goroutine and
// returns immediately. If the rebuild succeeds the new router replaces the one
// used by ServeHTTP; if it fails the error is logged and the existing router
// is left in place.
func (m *Manager) Compile() {
	go func() {
		// build into a fresh router so the live one is never half-updated
		fresh := New(m.p)

		if compileErr := m.internalCompile(fresh); compileErr != nil {
			log.Printf("[Manager] Compile failed: %s\n", compileErr)
			return
		}

		// the write lock is only needed for the pointer swap itself
		m.s.Lock()
		defer m.s.Unlock()
		m.r = fresh
	}()
}
// internalCompile is a hidden internal method for querying the database during
// the Compile() method. It loads every row of the routes query and then every
// row of the redirects query into router, stopping at the first error.
//
// Returned errors are wrapped with the step that failed (and the source of the
// offending row where there is one), so the message logged by Compile
// identifies what went wrong. The underlying error stays reachable through
// errors.Is and errors.As.
func (m *Manager) internalCompile(router *Router) error {
	log.Println("[Manager] Updating routes from database")

	if err := m.compileRoutes(router); err != nil {
		return fmt.Errorf("compiling routes: %w", err)
	}
	if err := m.compileRedirects(router); err != nil {
		return fmt.Errorf("compiling redirects: %w", err)
	}
	return nil
}

// compileRoutes runs the routes query and adds each resulting row to router.
func (m *Manager) compileRoutes(router *Router) error {
	rows, err := m.db.Query(queryTableRoutes)
	if err != nil {
		return err
	}
	// closes the result set on every return path, including early errors
	defer rows.Close()

	// loop through rows and scan the options
	for rows.Next() {
		var (
			src, dst                                                         string
			pre, abs, cors, secureMode, forwardHost, forwardAddr, ignoreCert bool
		)
		if err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secureMode, &forwardHost, &forwardAddr, &ignoreCert); err != nil {
			return err
		}

		err := addRoute(router, src, dst, target.Route{
			Pre:         pre,
			Abs:         abs,
			Cors:        cors,
			SecureMode:  secureMode,
			ForwardHost: forwardHost,
			ForwardAddr: forwardAddr,
			IgnoreCert:  ignoreCert,
		})
		if err != nil {
			return fmt.Errorf("route %q: %w", src, err)
		}
	}

	// rows.Next also returns false on failure, so check why the loop ended
	return rows.Err()
}

// compileRedirects runs the redirects query and adds each resulting row to
// router.
func (m *Manager) compileRedirects(router *Router) error {
	rows, err := m.db.Query(queryTableRedirects)
	if err != nil {
		return err
	}
	// closes the result set on every return path, including early errors
	defer rows.Close()

	// loop through rows and scan the options
	for rows.Next() {
		var (
			src, dst string
			pre, abs bool
			code     int
		)
		if err := rows.Scan(&src, &pre, &dst, &abs, &code); err != nil {
			return err
		}

		err := addRedirect(router, src, dst, target.Redirect{
			Pre:  pre,
			Abs:  abs,
			Code: code,
		})
		if err != nil {
			return fmt.Errorf("redirect %q: %w", src, err)
		}
	}

	// rows.Next also returns false on failure, so check why the loop ended
	return rows.Err()
}
// addRoute parses src and dst, copies the parsed destination host, port and
// path into t, and registers t on router under the parsed source host and
// path. A parse failure is returned unchanged and nothing is added.
func addRoute(router *Router, src string, dst string, t target.Route) error {
	fromHost, fromPath, toHost, toPort, toPath, parseErr := parseSrcDstHost(src, dst)
	if parseErr != nil {
		return parseErr
	}

	// t is a copy, so filling in the destination does not affect the caller
	t.Host, t.Port, t.Path = toHost, toPort, toPath
	router.AddRoute(fromHost, fromPath, t)
	return nil
}
// addRedirect parses src and dst, copies the parsed destination host, port
// and path into t, and registers t on router under the parsed source host and
// path. A parse failure is returned unchanged and nothing is added.
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
	fromHost, fromPath, toHost, toPort, toPath, parseErr := parseSrcDstHost(src, dst)
	if parseErr != nil {
		return parseErr
	}

	// t is a copy, so filling in the destination does not affect the caller
	t.Host, t.Port, t.Path = toHost, toPort, toPath
	router.AddRedirect(fromHost, fromPath, t)
	return nil
}
// parseSrcDstHost extracts the host/path from src and the host:port/path from
// dst.
//
// Both values are cut at their first '/': everything before it is the host
// part and everything from the slash onwards is the path. A value without a
// slash is treated as host only, with the path "/". The host part of dst is
// then handed to utils.SplitDomainPort (with 0 as its second argument) to
// obtain the destination host and port.
//
// The results are, in order: source host, source path, destination host,
// destination port, destination path. If the destination cannot be split into
// host and port, an error is returned and all other results are zero values.
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
	// splitAtSlash cuts s at its first '/', keeping the slash with the path
	splitAtSlash := func(s string) (string, string) {
		if i := strings.IndexByte(s, '/'); i != -1 {
			return s[:i], s[i:]
		}
		return s, "/"
	}

	srcHost, srcPath := splitAtSlash(src)
	dstAddr, dstPath := splitAtSlash(dst)

	// try to split the destination host into domain + port
	dstHost, dstPort, ok := utils.SplitDomainPort(dstAddr, 0)
	if ok {
		return srcHost, srcPath, dstHost, dstPort, dstPath, nil
	}
	return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dstAddr)
}