mirror of
https://github.com/1f349/violet.git
synced 2024-11-09 22:22:50 +00:00
241 lines
5.7 KiB
Go
241 lines
5.7 KiB
Go
package router
|
|
|
|
import (
|
|
"database/sql"
|
|
_ "embed"
|
|
"fmt"
|
|
"github.com/MrMelon54/rescheduler"
|
|
"github.com/MrMelon54/violet/proxy"
|
|
"github.com/MrMelon54/violet/target"
|
|
"github.com/MrMelon54/violet/utils"
|
|
"log"
|
|
"net/http"
|
|
"path"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Manager is a database and mutex wrap around router allowing it to be
|
|
// dynamically regenerated after updating the database of routes.
|
|
type Manager struct {
|
|
db *sql.DB
|
|
s *sync.RWMutex
|
|
r *Router
|
|
p *proxy.HybridTransport
|
|
z *rescheduler.Rescheduler
|
|
}
|
|
|
|
// SQL statements embedded into the binary at build time via go:embed.
var (
	// createTableRoutes is executed by NewManager to set up the routes table.
	//go:embed create-table-routes.sql
	createTableRoutes string

	// createTableRedirects is executed by NewManager to set up the redirects table.
	//go:embed create-table-redirects.sql
	createTableRedirects string

	// queryTableRoutes is run by internalCompile; each row it returns is
	// scanned as (src, pre, dst, abs, cors, secure_mode, forward_host,
	// forward_addr, ignore_cert), so its column order must match.
	//go:embed query-table-routes.sql
	queryTableRoutes string

	// queryTableRedirects is run by internalCompile; each row it returns is
	// scanned as (src, pre, dst, abs, code), so its column order must match.
	//go:embed query-table-redirects.sql
	queryTableRedirects string
)
|
|
|
|
// NewManager create a new manager, initialises the routes and redirects tables
|
|
// in the database and runs a first time compile.
|
|
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
|
|
m := &Manager{
|
|
db: db,
|
|
s: &sync.RWMutex{},
|
|
r: New(proxy),
|
|
p: proxy,
|
|
}
|
|
m.z = rescheduler.NewRescheduler(m.threadCompile)
|
|
|
|
// init routes table
|
|
_, err := m.db.Exec(createTableRoutes)
|
|
if err != nil {
|
|
log.Printf("[WARN] Failed to generate 'routes' table\n")
|
|
return nil
|
|
}
|
|
|
|
// init redirects table
|
|
_, err = m.db.Exec(createTableRedirects)
|
|
if err != nil {
|
|
log.Printf("[WARN] Failed to generate 'redirects' table\n")
|
|
return nil
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
m.s.RLock()
|
|
m.r.ServeHTTP(rw, req)
|
|
m.s.RUnlock()
|
|
}
|
|
|
|
func (m *Manager) Compile() {
|
|
m.z.Run()
|
|
}
|
|
|
|
func (m *Manager) threadCompile() {
|
|
// new router
|
|
router := New(m.p)
|
|
|
|
// compile router and check errors
|
|
err := m.internalCompile(router)
|
|
if err != nil {
|
|
log.Printf("[Manager] Compile failed: %s\n", err)
|
|
return
|
|
}
|
|
|
|
// lock while replacing router
|
|
m.s.Lock()
|
|
m.r = router
|
|
m.s.Unlock()
|
|
}
|
|
|
|
// internalCompile is a hidden internal method for querying the database during
|
|
// the Compile() method.
|
|
func (m *Manager) internalCompile(router *Router) error {
|
|
log.Println("[Manager] Updating routes from database")
|
|
|
|
// sql or something?
|
|
rows, err := m.db.Query(queryTableRoutes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// loop through rows and scan the options
|
|
for rows.Next() {
|
|
var (
|
|
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
|
|
src, dst string
|
|
)
|
|
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = addRoute(router, src, dst, target.Route{
|
|
Pre: pre,
|
|
Abs: abs,
|
|
Cors: cors,
|
|
SecureMode: secure_mode,
|
|
ForwardHost: forward_host,
|
|
ForwardAddr: forward_addr,
|
|
IgnoreCert: ignore_cert,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// check for errors
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// sql or something?
|
|
rows, err = m.db.Query(queryTableRedirects)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// loop through rows and scan the options
|
|
for rows.Next() {
|
|
var (
|
|
pre, abs bool
|
|
code int
|
|
src, dst string
|
|
)
|
|
err := rows.Scan(&src, &pre, &dst, &abs, &code)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = addRedirect(router, src, dst, target.Redirect{
|
|
Pre: pre,
|
|
Abs: abs,
|
|
Code: code,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// check for errors
|
|
return rows.Err()
|
|
}
|
|
|
|
func (m *Manager) Add(source string, route target.Route, active bool) {
|
|
m.s.Lock()
|
|
defer m.s.Unlock()
|
|
_, err := m.db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, source, route.Pre, path.Join(route.Host, route.Path), route.Abs, route.Cors, route.SecureMode, route.ForwardHost, route.ForwardAddr, route.IgnoreCert, active)
|
|
if err != nil {
|
|
log.Printf("[Violet] Database error: %s\n", err)
|
|
}
|
|
}
|
|
|
|
// addRoute is an alias to parse the src and dst then add the route
|
|
func addRoute(router *Router, src string, dst string, t target.Route) error {
|
|
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// update target route values and add route
|
|
t.Host = dstHost
|
|
t.Port = dstPort
|
|
t.Path = dstPath
|
|
router.AddRoute(srcHost, srcPath, t)
|
|
return nil
|
|
}
|
|
|
|
// addRedirect is an alias to parse the src and dst then add the redirect
|
|
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
|
|
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
t.Host = dstHost
|
|
t.Port = dstPort
|
|
t.Path = dstPath
|
|
router.AddRedirect(srcHost, srcPath, t)
|
|
return nil
|
|
}
|
|
|
|
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
|
|
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
|
|
// check if source has path
|
|
var srcHost, srcPath string
|
|
nSrc := strings.IndexByte(src, '/')
|
|
if nSrc == -1 {
|
|
// set host then path to /
|
|
srcHost = src
|
|
srcPath = "/"
|
|
} else {
|
|
// set host then custom path
|
|
srcHost = src[:nSrc]
|
|
srcPath = src[nSrc:]
|
|
}
|
|
|
|
// check if destination has path
|
|
var dstPath string
|
|
nDst := strings.IndexByte(dst, '/')
|
|
if nDst == -1 {
|
|
// set path to /
|
|
dstPath = "/"
|
|
} else {
|
|
// set custom path then trim dst string to the host
|
|
dstPath = dst[nDst:]
|
|
dst = dst[:nDst]
|
|
}
|
|
|
|
// try to split the destination host into domain + port
|
|
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
|
|
if !ok {
|
|
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
|
|
}
|
|
|
|
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
|
|
}
|