mirror of
https://github.com/1f349/violet.git
synced 2024-11-23 11:51:37 +00:00
229 lines
5.1 KiB
Go
229 lines
5.1 KiB
Go
package router
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
"github.com/1f349/violet/database"
|
|
"github.com/1f349/violet/logger"
|
|
"github.com/1f349/violet/proxy"
|
|
"github.com/1f349/violet/target"
|
|
"github.com/mrmelon54/rescheduler"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Logger is the logger used by the router manager. It is derived from the
// shared violet logger via WithPrefix("Violet Manager").
var (
	Logger = logger.Logger.WithPrefix("Violet Manager")
)
|
|
|
|
// Manager is a database and mutex wrap around router allowing it to be
|
|
// dynamically regenerated after updating the database of routes.
|
|
type Manager struct {
|
|
db *database.Queries
|
|
s *sync.RWMutex
|
|
r *Router
|
|
p *proxy.HybridTransport
|
|
z *rescheduler.Rescheduler
|
|
}
|
|
|
|
// NewManager create a new manager, initialises the routes and redirects tables
|
|
// in the database and runs a first time compile.
|
|
func NewManager(db *database.Queries, proxy *proxy.HybridTransport) *Manager {
|
|
m := &Manager{
|
|
db: db,
|
|
s: &sync.RWMutex{},
|
|
r: New(proxy),
|
|
p: proxy,
|
|
}
|
|
m.z = rescheduler.NewRescheduler(m.threadCompile)
|
|
return m
|
|
}
|
|
|
|
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
m.s.RLock()
|
|
r := m.r
|
|
m.s.RUnlock()
|
|
r.ServeHTTP(rw, req)
|
|
}
|
|
|
|
func (m *Manager) Compile() {
|
|
m.z.Run()
|
|
}
|
|
|
|
func (m *Manager) threadCompile() {
|
|
// new router
|
|
router := New(m.p)
|
|
|
|
// compile router and check errors
|
|
err := m.internalCompile(router)
|
|
if err != nil {
|
|
Logger.Info("Compile failed", "err", err)
|
|
return
|
|
}
|
|
|
|
// lock while replacing router
|
|
m.s.Lock()
|
|
m.r = router
|
|
m.s.Unlock()
|
|
}
|
|
|
|
// internalCompile is a hidden internal method for querying the database during
|
|
// the Compile() method.
|
|
func (m *Manager) internalCompile(router *Router) error {
|
|
Logger.Info("Updating routes from database")
|
|
|
|
// sql or something?
|
|
routeRows, err := m.db.GetActiveRoutes(context.Background())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, row := range routeRows {
|
|
router.AddRoute(target.Route{
|
|
Src: row.Source,
|
|
Dst: row.Destination,
|
|
Flags: row.Flags.NormaliseRouteFlags(),
|
|
})
|
|
}
|
|
|
|
// sql or something?
|
|
redirectsRows, err := m.db.GetActiveRedirects(context.Background())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, row := range redirectsRows {
|
|
router.AddRedirect(target.Redirect{
|
|
Src: row.Source,
|
|
Dst: row.Destination,
|
|
Flags: row.Flags.NormaliseRedirectFlags(),
|
|
Code: row.Code,
|
|
})
|
|
}
|
|
|
|
// check for errors
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) {
|
|
if len(hosts) < 1 {
|
|
return []target.RouteWithActive{}, nil
|
|
}
|
|
|
|
s := make([]target.RouteWithActive, 0)
|
|
|
|
rows, err := m.db.GetAllRoutes(context.Background())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, row := range rows {
|
|
a := target.RouteWithActive{
|
|
Route: target.Route{
|
|
Src: row.Source,
|
|
Dst: row.Destination,
|
|
Desc: row.Description,
|
|
Flags: row.Flags,
|
|
},
|
|
Active: row.Active,
|
|
}
|
|
|
|
for _, i := range hosts {
|
|
// if this is never true then the domain was mistakenly grabbed from the database
|
|
if a.OnDomain(i) {
|
|
s = append(s, a)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *Manager) InsertRoute(route target.RouteWithActive) error {
|
|
return m.db.AddRoute(context.Background(), database.AddRouteParams{
|
|
Source: route.Src,
|
|
Destination: route.Dst,
|
|
Description: route.Desc,
|
|
Flags: route.Flags,
|
|
Active: route.Active,
|
|
})
|
|
}
|
|
|
|
func (m *Manager) DeleteRoute(source string) error {
|
|
return m.db.RemoveRoute(context.Background(), source)
|
|
}
|
|
|
|
func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, error) {
|
|
if len(hosts) < 1 {
|
|
return []target.RedirectWithActive{}, nil
|
|
}
|
|
|
|
s := make([]target.RedirectWithActive, 0)
|
|
|
|
rows, err := m.db.GetAllRedirects(context.Background())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, row := range rows {
|
|
a := target.RedirectWithActive{
|
|
Redirect: target.Redirect{
|
|
Src: row.Source,
|
|
Dst: row.Destination,
|
|
Desc: row.Description,
|
|
Flags: row.Flags,
|
|
Code: row.Code,
|
|
},
|
|
Active: row.Active,
|
|
}
|
|
|
|
for _, i := range hosts {
|
|
// if this is never true then the domain was mistakenly grabbed from the database
|
|
if a.OnDomain(i) {
|
|
s = append(s, a)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (m *Manager) InsertRedirect(redirect target.RedirectWithActive) error {
|
|
return m.db.AddRedirect(context.Background(), database.AddRedirectParams{
|
|
Source: redirect.Src,
|
|
Destination: redirect.Dst,
|
|
Description: redirect.Desc,
|
|
Flags: redirect.Flags,
|
|
Code: redirect.Code,
|
|
Active: redirect.Active,
|
|
})
|
|
}
|
|
|
|
func (m *Manager) DeleteRedirect(source string) error {
|
|
return m.db.RemoveRedirect(context.Background(), source)
|
|
}
|
|
|
|
// GenerateHostSearch builds a SQL WHERE clause and its positional arguments.
// The clause pre-filters rows on the source column by host, so fewer rows are
// loaded from the database before an exact per-host check is done in Go.
//
// Each host gets two LIKE conditions:
//   - one matching a source that contains "<host>/";
//   - one matching a source that ends with "<host>".
//
// The patterns are deliberately loose: "notexample.com" also matches the host
// "example.com", and LIKE wildcards inside a host are not escaped. Results must
// still be filtered exactly afterwards.
//
// The returned args hold every host twice, in placeholder order. Patterns are
// built with the SQL "||" concatenation operator; "+" is numeric addition in
// SQLite and standard SQL, so it would not produce a pattern.
//
// With no hosts, the clause matches nothing. A bare "WHERE " would be invalid
// SQL.
//
// TODO(Melon) discover how to implement this correctly
func GenerateHostSearch(hosts []string) (string, []string) {
	hostArgs := make([]string, 0, len(hosts)*2)
	if len(hosts) == 0 {
		return "WHERE 1 = 0", hostArgs
	}

	var searchString strings.Builder
	searchString.WriteString("WHERE ")
	for i, host := range hosts {
		if i != 0 {
			searchString.WriteString(" OR ")
		}
		// These LIKE checks are not exact, but they reduce load on the database.
		searchString.WriteString("(source LIKE '%' || ? || '/%' OR source LIKE '%' || ?)")
		// One argument per placeholder above.
		hostArgs = append(hostArgs, host, host)
	}
	return searchString.String(), hostArgs
}
|