mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 19:01:39 +00:00
Write route/redirect APIs and rearrage some other code to make it possible
This commit is contained in:
parent
c930ddff28
commit
949dcd298a
@ -1,11 +1,11 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
||||||
<data-source source="LOCAL" name="__db.sqlite" uuid="5aeb4e88-8ec4-4227-a921-ba4eaed357bf">
|
<data-source source="LOCAL" name="identifier.sqlite" uuid="5b42d21a-92a8-43d0-8651-c1555b91060c">
|
||||||
<driver-ref>sqlite.xerial</driver-ref>
|
<driver-ref>sqlite.xerial</driver-ref>
|
||||||
<synchronize>true</synchronize>
|
<synchronize>true</synchronize>
|
||||||
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
||||||
<jdbc-url>jdbc:sqlite:__db.sqlite</jdbc-url>
|
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
|
||||||
<working-dir>$ProjectFileDir$</working-dir>
|
<working-dir>$ProjectFileDir$</working-dir>
|
||||||
</data-source>
|
</data-source>
|
||||||
</component>
|
</component>
|
||||||
|
@ -14,6 +14,8 @@ import (
|
|||||||
"github.com/MrMelon54/violet/proxy"
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/MrMelon54/violet/router"
|
"github.com/MrMelon54/violet/router"
|
||||||
"github.com/MrMelon54/violet/servers"
|
"github.com/MrMelon54/violet/servers"
|
||||||
|
"github.com/MrMelon54/violet/servers/api"
|
||||||
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/google/subcommands"
|
"github.com/google/subcommands"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
@ -70,9 +72,9 @@ func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{
|
|||||||
return subcommands.ExitSuccess
|
return subcommands.ExitSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalLoad(conf startUpConfig, wd string) {
|
func normalLoad(startUp startUpConfig, wd string) {
|
||||||
// the cert and key paths are useless in self-signed mode
|
// the cert and key paths are useless in self-signed mode
|
||||||
if !conf.SelfSigned {
|
if !startUp.SelfSigned {
|
||||||
// create path to cert dir
|
// create path to cert dir
|
||||||
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
|
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -87,11 +89,11 @@ func normalLoad(conf startUpConfig, wd string) {
|
|||||||
|
|
||||||
// errorPageDir stores an FS interface for accessing the error page directory
|
// errorPageDir stores an FS interface for accessing the error page directory
|
||||||
var errorPageDir fs.FS
|
var errorPageDir fs.FS
|
||||||
if conf.ErrorPagePath != "" {
|
if startUp.ErrorPagePath != "" {
|
||||||
errorPageDir = os.DirFS(conf.ErrorPagePath)
|
errorPageDir = os.DirFS(startUp.ErrorPagePath)
|
||||||
err := os.MkdirAll(conf.ErrorPagePath, os.ModePerm)
|
err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[Violet] Failed to create error page path '%s'", conf.ErrorPagePath)
|
log.Fatalf("[Violet] Failed to create error page path '%s'", startUp.ErrorPagePath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,20 +112,20 @@ func normalLoad(conf startUpConfig, wd string) {
|
|||||||
certDir := os.DirFS(filepath.Join(wd, "certs"))
|
certDir := os.DirFS(filepath.Join(wd, "certs"))
|
||||||
keyDir := os.DirFS(filepath.Join(wd, "keys"))
|
keyDir := os.DirFS(filepath.Join(wd, "keys"))
|
||||||
|
|
||||||
allowedDomains := domains.New(db) // load allowed domains
|
allowedDomains := domains.New(db) // load allowed domains
|
||||||
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
|
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
|
||||||
allowedCerts := certs.New(certDir, keyDir, conf.SelfSigned) // load certificate manager
|
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
|
||||||
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
|
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
|
||||||
dynamicFavicons := favicons.New(db, conf.InkscapeCmd) // load dynamic favicon provider
|
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
|
||||||
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
|
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
|
||||||
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
|
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
|
||||||
|
|
||||||
// struct containing config for the http servers
|
// struct containing config for the http servers
|
||||||
srvConf := &servers.Conf{
|
srvConf := &conf.Conf{
|
||||||
ApiListen: conf.Listen.Api,
|
ApiListen: startUp.Listen.Api,
|
||||||
HttpListen: conf.Listen.Http,
|
HttpListen: startUp.Listen.Http,
|
||||||
HttpsListen: conf.Listen.Https,
|
HttpsListen: startUp.Listen.Https,
|
||||||
RateLimit: conf.RateLimit,
|
RateLimit: startUp.RateLimit,
|
||||||
DB: db,
|
DB: db,
|
||||||
Domains: allowedDomains,
|
Domains: allowedDomains,
|
||||||
Acme: acmeChallenges,
|
Acme: acmeChallenges,
|
||||||
@ -140,7 +142,7 @@ func normalLoad(conf startUpConfig, wd string) {
|
|||||||
|
|
||||||
var srvApi, srvHttp, srvHttps *http.Server
|
var srvApi, srvHttp, srvHttps *http.Server
|
||||||
if srvConf.ApiListen != "" {
|
if srvConf.ApiListen != "" {
|
||||||
srvApi = servers.NewApiServer(srvConf, allCompilables)
|
srvApi = api.NewApiServer(srvConf, allCompilables)
|
||||||
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
|
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
|
||||||
go utils.RunBackgroundHttp("API", srvApi)
|
go utils.RunBackgroundHttp("API", srvApi)
|
||||||
}
|
}
|
||||||
|
@ -181,13 +181,15 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
|||||||
// add with the route manager, no need to compile as this will run when opened
|
// add with the route manager, no need to compile as this will run when opened
|
||||||
// with the serve subcommand
|
// with the serve subcommand
|
||||||
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}))
|
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}))
|
||||||
routeManager.Add(path.Join(apiUrl.Host, apiUrl.Path), target.Route{
|
err = routeManager.InsertRoute(target.Route{
|
||||||
Pre: true,
|
Src: path.Join(apiUrl.Host, apiUrl.Path),
|
||||||
Host: answers.ApiListen,
|
Dst: answers.ApiListen,
|
||||||
Cors: true,
|
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
|
||||||
ForwardHost: true,
|
})
|
||||||
ForwardAddr: true,
|
if err != nil {
|
||||||
}, true)
|
fmt.Println("[Violet] Failed to insert api route into database: ", err)
|
||||||
|
return subcommands.ExitFailure
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("[Violet] Setup complete")
|
fmt.Println("[Violet] Setup complete")
|
||||||
|
7
favicons/create-table-favicons.sql
Normal file
7
favicons/create-table-favicons.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS favicons (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
host VARCHAR,
|
||||||
|
svg VARCHAR,
|
||||||
|
png VARCHAR,
|
||||||
|
ico VARCHAR
|
||||||
|
);
|
@ -2,6 +2,7 @@ package favicons
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/rescheduler"
|
"github.com/MrMelon54/rescheduler"
|
||||||
@ -12,6 +13,9 @@ import (
|
|||||||
|
|
||||||
var ErrFaviconNotFound = errors.New("favicon not found")
|
var ErrFaviconNotFound = errors.New("favicon not found")
|
||||||
|
|
||||||
|
//go:embed create-table-favicons.sql
|
||||||
|
var createTableFavicons string
|
||||||
|
|
||||||
// Favicons is a dynamic favicon generator which supports overwriting favicons
|
// Favicons is a dynamic favicon generator which supports overwriting favicons
|
||||||
type Favicons struct {
|
type Favicons struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
@ -32,7 +36,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
|
|||||||
f.r = rescheduler.NewRescheduler(f.threadCompile)
|
f.r = rescheduler.NewRescheduler(f.threadCompile)
|
||||||
|
|
||||||
// init favicons table
|
// init favicons table
|
||||||
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
|
_, err := f.db.Exec(createTableFavicons)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Failed to generate 'favicons' table\n")
|
log.Printf("[WARN] Failed to generate 'favicons' table\n")
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
CREATE TABLE IF NOT EXISTS redirects
|
|
||||||
(
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
source TEXT,
|
|
||||||
pre INTEGER,
|
|
||||||
destination TEXT,
|
|
||||||
abs INTEGER,
|
|
||||||
code INTEGER,
|
|
||||||
active INTEGER DEFAULT 1
|
|
||||||
);
|
|
@ -1,14 +0,0 @@
|
|||||||
CREATE TABLE IF NOT EXISTS routes
|
|
||||||
(
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
source TEXT,
|
|
||||||
pre INTEGER,
|
|
||||||
destination TEXT,
|
|
||||||
abs INTEGER,
|
|
||||||
cors INTEGER,
|
|
||||||
secure_mode INTEGER,
|
|
||||||
forward_host INTEGER,
|
|
||||||
forward_addr INTEGER,
|
|
||||||
ignore_cert INTEGER,
|
|
||||||
active INTEGER DEFAULT 1
|
|
||||||
);
|
|
18
router/create-tables.sql
Normal file
18
router/create-tables.sql
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS routes
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
source TEXT UNIQUE,
|
||||||
|
destination TEXT,
|
||||||
|
flags INTEGER DEFAULT 0,
|
||||||
|
active INTEGER DEFAULT 1
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS redirects
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
source TEXT UNIQUE,
|
||||||
|
destination TEXT,
|
||||||
|
flags INTEGER DEFAULT 0,
|
||||||
|
code INTEGER DEFAULT 0,
|
||||||
|
active INTEGER DEFAULT 1
|
||||||
|
);
|
@ -3,15 +3,11 @@ package router
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
|
||||||
"github.com/MrMelon54/rescheduler"
|
"github.com/MrMelon54/rescheduler"
|
||||||
"github.com/MrMelon54/violet/proxy"
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/MrMelon54/violet/target"
|
"github.com/MrMelon54/violet/target"
|
||||||
"github.com/MrMelon54/violet/utils"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,14 +22,8 @@ type Manager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
//go:embed create-table-routes.sql
|
//go:embed create-tables.sql
|
||||||
createTableRoutes string
|
createTables string
|
||||||
//go:embed create-table-redirects.sql
|
|
||||||
createTableRedirects string
|
|
||||||
//go:embed query-table-routes.sql
|
|
||||||
queryTableRoutes string
|
|
||||||
//go:embed query-table-redirects.sql
|
|
||||||
queryTableRedirects string
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewManager create a new manager, initialises the routes and redirects tables
|
// NewManager create a new manager, initialises the routes and redirects tables
|
||||||
@ -48,16 +38,9 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
|
|||||||
m.z = rescheduler.NewRescheduler(m.threadCompile)
|
m.z = rescheduler.NewRescheduler(m.threadCompile)
|
||||||
|
|
||||||
// init routes table
|
// init routes table
|
||||||
_, err := m.db.Exec(createTableRoutes)
|
_, err := m.db.Exec(createTables)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARN] Failed to generate 'routes' table\n")
|
log.Printf("[WARN] Failed to generate tables\n")
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// init redirects table
|
|
||||||
_, err = m.db.Exec(createTableRedirects)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[WARN] Failed to generate 'redirects' table\n")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
@ -96,7 +79,7 @@ func (m *Manager) internalCompile(router *Router) error {
|
|||||||
log.Println("[Manager] Updating routes from database")
|
log.Println("[Manager] Updating routes from database")
|
||||||
|
|
||||||
// sql or something?
|
// sql or something?
|
||||||
rows, err := m.db.Query(queryTableRoutes)
|
rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -105,26 +88,19 @@ func (m *Manager) internalCompile(router *Router) error {
|
|||||||
// loop through rows and scan the options
|
// loop through rows and scan the options
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
|
src, dst string
|
||||||
src, dst string
|
flags target.Flags
|
||||||
)
|
)
|
||||||
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
|
err := rows.Scan(&src, &dst, &flags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addRoute(router, src, dst, target.Route{
|
router.AddRoute(target.Route{
|
||||||
Pre: pre,
|
Src: src,
|
||||||
Abs: abs,
|
Dst: dst,
|
||||||
Cors: cors,
|
Flags: flags.NormaliseRouteFlags(),
|
||||||
SecureMode: secure_mode,
|
|
||||||
ForwardHost: forward_host,
|
|
||||||
ForwardAddr: forward_addr,
|
|
||||||
IgnoreCert: ignore_cert,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for errors
|
// check for errors
|
||||||
@ -133,7 +109,7 @@ func (m *Manager) internalCompile(router *Router) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sql or something?
|
// sql or something?
|
||||||
rows, err = m.db.Query(queryTableRedirects)
|
rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -142,99 +118,51 @@ func (m *Manager) internalCompile(router *Router) error {
|
|||||||
// loop through rows and scan the options
|
// loop through rows and scan the options
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
pre, abs bool
|
|
||||||
code int
|
|
||||||
src, dst string
|
src, dst string
|
||||||
|
flags target.Flags
|
||||||
|
code int
|
||||||
)
|
)
|
||||||
err := rows.Scan(&src, &pre, &dst, &abs, &code)
|
err := rows.Scan(&src, &dst, &flags, &code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addRedirect(router, src, dst, target.Redirect{
|
router.AddRedirect(target.Redirect{
|
||||||
Pre: pre,
|
Src: src,
|
||||||
Abs: abs,
|
Dst: dst,
|
||||||
Code: code,
|
Flags: flags.NormaliseRedirectFlags(),
|
||||||
|
Code: code,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for errors
|
// check for errors
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Add(source string, route target.Route, active bool) {
|
func (m *Manager) InsertRoute(route target.Route) error {
|
||||||
m.s.Lock()
|
m.s.Lock()
|
||||||
defer m.s.Unlock()
|
defer m.s.Unlock()
|
||||||
_, err := m.db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, source, route.Pre, path.Join(route.Host, route.Path), route.Abs, route.Cors, route.SecureMode, route.ForwardHost, route.ForwardAddr, route.IgnoreCert, active)
|
_, err := m.db.Exec(`INSERT INTO routes (source, destination, flags) VALUES (?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, active = 1`, route.Src, route.Dst, route.Flags)
|
||||||
if err != nil {
|
return err
|
||||||
log.Printf("[Violet] Database error: %s\n", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addRoute is an alias to parse the src and dst then add the route
|
func (m *Manager) DeleteRoute(source string) error {
|
||||||
func addRoute(router *Router, src string, dst string, t target.Route) error {
|
m.s.Lock()
|
||||||
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
defer m.s.Unlock()
|
||||||
if err != nil {
|
_, err := m.db.Exec(`UPDATE routes SET active = 0 WHERE source = ?`, source)
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
// update target route values and add route
|
|
||||||
t.Host = dstHost
|
|
||||||
t.Port = dstPort
|
|
||||||
t.Path = dstPath
|
|
||||||
router.AddRoute(srcHost, srcPath, t)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addRedirect is an alias to parse the src and dst then add the redirect
|
func (m *Manager) InsertRedirect(redirect target.Redirect) error {
|
||||||
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
|
m.s.Lock()
|
||||||
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
defer m.s.Unlock()
|
||||||
if err != nil {
|
_, err := m.db.Exec(`INSERT INTO redirects (source, destination, flags, code) VALUES (?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, code = excluded.code, active = 1`, redirect.Src, redirect.Dst, redirect.Flags, redirect.Code)
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
t.Host = dstHost
|
|
||||||
t.Port = dstPort
|
|
||||||
t.Path = dstPath
|
|
||||||
router.AddRedirect(srcHost, srcPath, t)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
|
func (m *Manager) DeleteRedirect(source string) error {
|
||||||
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
|
m.s.Lock()
|
||||||
// check if source has path
|
defer m.s.Unlock()
|
||||||
var srcHost, srcPath string
|
_, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source)
|
||||||
nSrc := strings.IndexByte(src, '/')
|
return err
|
||||||
if nSrc == -1 {
|
|
||||||
// set host then path to /
|
|
||||||
srcHost = src
|
|
||||||
srcPath = "/"
|
|
||||||
} else {
|
|
||||||
// set host then custom path
|
|
||||||
srcHost = src[:nSrc]
|
|
||||||
srcPath = src[nSrc:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if destination has path
|
|
||||||
var dstPath string
|
|
||||||
nDst := strings.IndexByte(dst, '/')
|
|
||||||
if nDst == -1 {
|
|
||||||
// set path to /
|
|
||||||
dstPath = "/"
|
|
||||||
} else {
|
|
||||||
// set custom path then trim dst string to the host
|
|
||||||
dstPath = dst[nDst:]
|
|
||||||
dst = dst[:nDst]
|
|
||||||
}
|
|
||||||
|
|
||||||
// try to split the destination host into domain + port
|
|
||||||
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
|
|
||||||
if !ok {
|
|
||||||
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
|
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package router
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/MrMelon54/violet/proxy"
|
"github.com/MrMelon54/violet/proxy"
|
||||||
|
"github.com/MrMelon54/violet/target"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -37,7 +38,7 @@ func TestNewManager(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||||
assert.Nil(t, ft.req)
|
assert.Nil(t, ft.req)
|
||||||
|
|
||||||
_, err = db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?,?,?,?,?,?,?,?,?,?)`, "*.example.com", 0, "127.0.0.1:8080", 1, 0, 0, 1, 1, 0, 1)
|
_, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.NoError(t, m.internalCompile(m.r))
|
assert.NoError(t, m.internalCompile(m.r))
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
select source,
|
|
||||||
pre,
|
|
||||||
destination,
|
|
||||||
abs,
|
|
||||||
code
|
|
||||||
from redirects
|
|
||||||
where active = true
|
|
@ -1,11 +0,0 @@
|
|||||||
select source,
|
|
||||||
pre,
|
|
||||||
destination,
|
|
||||||
abs,
|
|
||||||
cors,
|
|
||||||
secure_mode,
|
|
||||||
forward_host,
|
|
||||||
forward_addr,
|
|
||||||
ignore_cert
|
|
||||||
from routes
|
|
||||||
where active = true
|
|
@ -46,16 +46,14 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] {
|
|||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) AddService(host string, t target.Route) {
|
func (r *Router) AddRoute(t target.Route) {
|
||||||
r.AddRoute(host, "/", t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Router) AddRoute(host string, path string, t target.Route) {
|
|
||||||
t.Proxy = r.proxy
|
t.Proxy = r.proxy
|
||||||
|
host, path := utils.SplitHostPath(t.Src)
|
||||||
r.hostRoute(host).PutString(path, t)
|
r.hostRoute(host).PutString(path, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) AddRedirect(host, path string, t target.Redirect) {
|
func (r *Router) AddRedirect(t target.Redirect) {
|
||||||
|
host, path := utils.SplitHostPath(t.Src)
|
||||||
r.hostRedirect(host).PutString(path, t)
|
r.hostRedirect(host).PutString(path, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +93,7 @@ func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host
|
|||||||
if h != nil {
|
if h != nil {
|
||||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||||
for i := len(pairs) - 1; i >= 0; i-- {
|
for i := len(pairs) - 1; i >= 0; i-- {
|
||||||
if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path {
|
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||||
pairs[i].Value.ServeHTTP(rw, req)
|
pairs[i].Value.ServeHTTP(rw, req)
|
||||||
return true
|
return true
|
||||||
@ -110,7 +108,7 @@ func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, ho
|
|||||||
if h != nil {
|
if h != nil {
|
||||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||||
for i := len(pairs) - 1; i >= 0; i-- {
|
for i := len(pairs) - 1; i >= 0; i-- {
|
||||||
if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path {
|
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||||
pairs[i].Value.ServeHTTP(rw, req)
|
pairs[i].Value.ServeHTTP(rw, req)
|
||||||
return true
|
return true
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,31 +30,31 @@ var (
|
|||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Path: "/world"}, mss{
|
{"/", target.Route{Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Abs: true}, mss{
|
{"/", target.Route{Flags: target.FlagAbs}, mss{
|
||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Abs: true, Path: "world"}, mss{
|
{"/", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Pre: true}, mss{
|
{"/", target.Route{Flags: target.FlagPre}, mss{
|
||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "/hello",
|
"/hello": "/hello",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Pre: true, Path: "world"}, mss{
|
{"/", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "/world/hello",
|
"/hello": "/world/hello",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Pre: true, Abs: true}, mss{
|
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
}},
|
}},
|
||||||
{"/", target.Route{Pre: true, Abs: true, Path: "world"}, mss{
|
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
}},
|
}},
|
||||||
@ -62,37 +63,37 @@ var (
|
|||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Path: "world"}, mss{
|
{"/hello", target.Route{Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Abs: true}, mss{
|
{"/hello", target.Route{Flags: target.FlagAbs}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Abs: true, Path: "world"}, mss{
|
{"/hello", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Pre: true}, mss{
|
{"/hello", target.Route{Flags: target.FlagPre}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "/hi",
|
"/hello/hi": "/hi",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Pre: true, Path: "world"}, mss{
|
{"/hello", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "/world/hi",
|
"/hello/hi": "/world/hi",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Pre: true, Abs: true}, mss{
|
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "/",
|
"/hello/hi": "/",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Route{Pre: true, Abs: true, Path: "world"}, mss{
|
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "/world",
|
"/hello/hi": "/world",
|
||||||
@ -103,31 +104,31 @@ var (
|
|||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Path: "world"}, mss{
|
{"/", target.Redirect{Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Abs: true}, mss{
|
{"/", target.Redirect{Flags: target.FlagAbs}, mss{
|
||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Abs: true, Path: "world"}, mss{
|
{"/", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "",
|
"/hello": "",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Pre: true}, mss{
|
{"/", target.Redirect{Flags: target.FlagPre}, mss{
|
||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "/hello",
|
"/hello": "/hello",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Pre: true, Path: "world"}, mss{
|
{"/", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "/world/hello",
|
"/hello": "/world/hello",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Pre: true, Abs: true}, mss{
|
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||||
"/": "/",
|
"/": "/",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
}},
|
}},
|
||||||
{"/", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{
|
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "/world",
|
"/": "/world",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
}},
|
}},
|
||||||
@ -136,37 +137,37 @@ var (
|
|||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Path: "world"}, mss{
|
{"/hello", target.Redirect{Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Abs: true}, mss{
|
{"/hello", target.Redirect{Flags: target.FlagAbs}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Abs: true, Path: "world"}, mss{
|
{"/hello", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "",
|
"/hello/hi": "",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Pre: true}, mss{
|
{"/hello", target.Redirect{Flags: target.FlagPre}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "/hi",
|
"/hello/hi": "/hi",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Pre: true, Path: "world"}, mss{
|
{"/hello", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "/world/hi",
|
"/hello/hi": "/world/hi",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Pre: true, Abs: true}, mss{
|
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/",
|
"/hello": "/",
|
||||||
"/hello/hi": "/",
|
"/hello/hi": "/",
|
||||||
}},
|
}},
|
||||||
{"/hello", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{
|
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||||
"/": "",
|
"/": "",
|
||||||
"/hello": "/world",
|
"/hello": "/world",
|
||||||
"/hello/hi": "/world",
|
"/hello/hi": "/world",
|
||||||
@ -181,10 +182,10 @@ func TestRouter_AddRoute(t *testing.T) {
|
|||||||
for _, i := range routeTests {
|
for _, i := range routeTests {
|
||||||
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
|
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
|
||||||
dst := i.dst
|
dst := i.dst
|
||||||
dst.Host = "127.0.0.1"
|
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
|
||||||
dst.Port = 8080
|
dst.Src = path.Join("example.com", i.path)
|
||||||
t.Logf("Running tests for %#v\n", dst)
|
t.Logf("Running tests for %#v\n", dst)
|
||||||
r.AddRoute("example.com", i.path, dst)
|
r.AddRoute(dst)
|
||||||
for k, v := range i.tests {
|
for k, v := range i.tests {
|
||||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
|
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
|
||||||
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
|
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
|
||||||
@ -217,10 +218,11 @@ func TestRouter_AddRedirect(t *testing.T) {
|
|||||||
for _, i := range redirectTests {
|
for _, i := range redirectTests {
|
||||||
r := New(nil)
|
r := New(nil)
|
||||||
dst := i.dst
|
dst := i.dst
|
||||||
dst.Host = "example.com"
|
dst.Dst = path.Join("example.com", dst.Dst)
|
||||||
dst.Code = http.StatusFound
|
dst.Code = http.StatusFound
|
||||||
|
dst.Src = path.Join("www.example.com", i.path)
|
||||||
t.Logf("Running tests for %#v\n", dst)
|
t.Logf("Running tests for %#v\n", dst)
|
||||||
r.AddRedirect("www.example.com", i.path, dst)
|
r.AddRedirect(dst)
|
||||||
for k, v := range i.tests {
|
for k, v := range i.tests {
|
||||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
|
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
|
||||||
if v == "" {
|
if v == "" {
|
||||||
@ -266,10 +268,10 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
|
|||||||
for _, i := range routeTests {
|
for _, i := range routeTests {
|
||||||
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
|
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
|
||||||
dst := i.dst
|
dst := i.dst
|
||||||
dst.Host = "127.0.0.1"
|
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
|
||||||
dst.Port = 8080
|
dst.Src = path.Join("example.com", i.path)
|
||||||
t.Logf("Running tests for %#v\n", dst)
|
t.Logf("Running tests for %#v\n", dst)
|
||||||
r.AddRoute("example.com", i.path, dst)
|
r.AddRoute(dst)
|
||||||
for k, v := range i.tests {
|
for k, v := range i.tests {
|
||||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
|
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
|
||||||
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
|
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
|
||||||
|
115
servers/api.go
115
servers/api.go
@ -1,115 +0,0 @@
|
|||||||
package servers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/MrMelon54/mjwt"
|
|
||||||
"github.com/MrMelon54/mjwt/auth"
|
|
||||||
"github.com/MrMelon54/violet/utils"
|
|
||||||
"github.com/julienschmidt/httprouter"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewApiServer creates and runs a http server containing all the API
|
|
||||||
// endpoints for the software
|
|
||||||
//
|
|
||||||
// `/compile` - reloads all domains, routes and redirects
|
|
||||||
func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server {
|
|
||||||
r := httprouter.New()
|
|
||||||
|
|
||||||
// Endpoint for compile action
|
|
||||||
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
|
||||||
if !hasPerms(conf.Signer, req, "violet:compile") {
|
|
||||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trigger the compile action
|
|
||||||
compileTarget.Compile()
|
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Endpoint for domains
|
|
||||||
r.PUT("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
|
||||||
if !hasPerms(conf.Signer, req, "violet:domains") {
|
|
||||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// add domain with active state
|
|
||||||
q := req.URL.Query()
|
|
||||||
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
|
|
||||||
conf.Domains.Compile()
|
|
||||||
})
|
|
||||||
r.DELETE("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
|
||||||
if !hasPerms(conf.Signer, req, "violet:domains") {
|
|
||||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// add domain with active state
|
|
||||||
q := req.URL.Query()
|
|
||||||
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
|
|
||||||
conf.Domains.Compile()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Endpoint for routes
|
|
||||||
r.POST("/route", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
// Endpoint for acme-challenge
|
|
||||||
r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
|
||||||
if !hasPerms(conf.Signer, req, "violet:acme-challenge") {
|
|
||||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
domain := params.ByName("domain")
|
|
||||||
if !conf.Domains.IsValid(domain) {
|
|
||||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conf.Acme.Put(domain, params.ByName("key"), params.ByName("value"))
|
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
|
||||||
})
|
|
||||||
r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
|
||||||
if !hasPerms(conf.Signer, req, "violet:acme-challenge") {
|
|
||||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
domain := params.ByName("domain")
|
|
||||||
if !conf.Domains.IsValid(domain) {
|
|
||||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conf.Acme.Delete(domain, params.ByName("key"))
|
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create and run http server
|
|
||||||
return &http.Server{
|
|
||||||
Addr: conf.ApiListen,
|
|
||||||
Handler: r,
|
|
||||||
ReadTimeout: time.Minute,
|
|
||||||
ReadHeaderTimeout: time.Minute,
|
|
||||||
WriteTimeout: time.Minute,
|
|
||||||
IdleTimeout: time.Minute,
|
|
||||||
MaxHeaderBytes: 2500,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool {
|
|
||||||
// Get bearer token
|
|
||||||
bearer := utils.GetBearer(req)
|
|
||||||
if bearer == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read claims from mjwt
|
|
||||||
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token must have perm
|
|
||||||
return b.Claims.Perms.Has(perm)
|
|
||||||
}
|
|
102
servers/api/api.go
Normal file
102
servers/api/api.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewApiServer creates and runs a http server containing all the API
|
||||||
|
// endpoints for the software
|
||||||
|
//
|
||||||
|
// `/compile` - reloads all domains, routes and redirects
|
||||||
|
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Server {
|
||||||
|
r := httprouter.New()
|
||||||
|
|
||||||
|
// Endpoint for compile action
|
||||||
|
r.POST("/compile", checkAuthWithPerm(conf.Signer, "violet:compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, b AuthClaims) {
|
||||||
|
// Trigger the compile action
|
||||||
|
compileTarget.Compile()
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Endpoint for domains
|
||||||
|
domainFunc := domainManage(conf.Signer, conf.Domains)
|
||||||
|
r.PUT("/domain/:domain", domainFunc)
|
||||||
|
r.DELETE("/domain/:domain", domainFunc)
|
||||||
|
|
||||||
|
// Endpoint code for target routes/redirects
|
||||||
|
targetApis := SetupTargetApis(conf.Signer, conf.Router)
|
||||||
|
|
||||||
|
// Endpoint for routes
|
||||||
|
r.POST("/route", targetApis.CreateRoute)
|
||||||
|
r.DELETE("/route", targetApis.DeleteRoute)
|
||||||
|
|
||||||
|
// Endpoint for redirects
|
||||||
|
r.POST("/redirect", targetApis.CreateRedirect)
|
||||||
|
r.DELETE("/redirect", targetApis.DeleteRedirect)
|
||||||
|
|
||||||
|
// Endpoint for acme-challenge
|
||||||
|
acmeChallengeFunc := acmeChallengeManage(conf.Signer, conf.Domains, conf.Acme)
|
||||||
|
r.PUT("/acme-challenge/:domain/:key/:value", acmeChallengeFunc)
|
||||||
|
r.DELETE("/acme-challenge/:domain/:key", acmeChallengeFunc)
|
||||||
|
|
||||||
|
// Create and run http server
|
||||||
|
return &http.Server{
|
||||||
|
Addr: conf.ApiListen,
|
||||||
|
Handler: r,
|
||||||
|
ReadTimeout: time.Minute,
|
||||||
|
ReadHeaderTimeout: time.Minute,
|
||||||
|
WriteTimeout: time.Minute,
|
||||||
|
IdleTimeout: time.Minute,
|
||||||
|
MaxHeaderBytes: 2500,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiError outputs a generic JSON error message
|
||||||
|
func apiError(rw http.ResponseWriter, code int, m string) {
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
_ = json.NewEncoder(rw).Encode(map[string]string{
|
||||||
|
"error": m,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainManage(verify mjwt.Verifier, domains utils.DomainProvider) httprouter.Handle {
|
||||||
|
return checkAuthWithPerm(verify, "violet:domains", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
// add domain with active state
|
||||||
|
domains.Put(params.ByName("domain"), req.Method == http.MethodPut)
|
||||||
|
domains.Compile()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func acmeChallengeManage(verify mjwt.Verifier, domains utils.DomainProvider, acme utils.AcmeChallengeProvider) httprouter.Handle {
|
||||||
|
return checkAuthWithPerm(verify, "violet:acme-challenge", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
domain := params.ByName("domain")
|
||||||
|
if !domains.IsValid(domain) {
|
||||||
|
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Method == http.MethodPut {
|
||||||
|
acme.Put(domain, params.ByName("key"), params.ByName("value"))
|
||||||
|
} else {
|
||||||
|
acme.Delete(domain, params.ByName("key"))
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDomainOwnershipClaims validates if the claims contain the
|
||||||
|
// `owns=<fqdn>` field with the matching top level domain
|
||||||
|
func validateDomainOwnershipClaims(a string, perms *claims.PermStorage) bool {
|
||||||
|
if fqdn, ok := utils.GetTopFqdn(a); ok {
|
||||||
|
if perms.Has("owns=" + fqdn) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
@ -1,59 +1,22 @@
|
|||||||
package servers
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
"crypto/rsa"
|
|
||||||
"github.com/MrMelon54/mjwt"
|
|
||||||
"github.com/MrMelon54/mjwt/auth"
|
|
||||||
"github.com/MrMelon54/mjwt/claims"
|
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/MrMelon54/violet/utils/fake"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var snakeOilProv = genSnakeOilProv()
|
|
||||||
|
|
||||||
type fakeDomains struct{}
|
|
||||||
|
|
||||||
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
|
|
||||||
func (f *fakeDomains) Put(string, bool) {}
|
|
||||||
func (f *fakeDomains) Delete(string) {}
|
|
||||||
func (f *fakeDomains) Compile() {}
|
|
||||||
|
|
||||||
func genSnakeOilProv() mjwt.Signer {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return mjwt.NewMJwtSigner("violet.test", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func genSnakeOilKey(perm string) string {
|
|
||||||
p := claims.NewPermStorage()
|
|
||||||
p.Set(perm)
|
|
||||||
val, err := snakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeCompilable struct{ done bool }
|
|
||||||
|
|
||||||
func (f *fakeCompilable) Compile() { f.done = true }
|
|
||||||
|
|
||||||
var _ utils.Compilable = &fakeCompilable{}
|
|
||||||
|
|
||||||
func TestNewApiServer_Compile(t *testing.T) {
|
func TestNewApiServer_Compile(t *testing.T) {
|
||||||
apiConf := &Conf{
|
apiConf := &conf.Conf{
|
||||||
Domains: &fakeDomains{},
|
Domains: &fake.Domains{},
|
||||||
Acme: utils.NewAcmeChallenge(),
|
Acme: utils.NewAcmeChallenge(),
|
||||||
Signer: snakeOilProv,
|
Signer: fake.SnakeOilProv,
|
||||||
}
|
}
|
||||||
f := &fakeCompilable{}
|
f := &fake.Compilable{}
|
||||||
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
|
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
|
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
|
||||||
@ -63,25 +26,25 @@ func TestNewApiServer_Compile(t *testing.T) {
|
|||||||
srv.Handler.ServeHTTP(rec, req)
|
srv.Handler.ServeHTTP(rec, req)
|
||||||
res := rec.Result()
|
res := rec.Result()
|
||||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||||
assert.False(t, f.done)
|
assert.False(t, f.Done)
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile"))
|
req.Header.Set("Authorization", "Bearer "+fake.GenSnakeOilKey("violet:compile"))
|
||||||
|
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
srv.Handler.ServeHTTP(rec, req)
|
srv.Handler.ServeHTTP(rec, req)
|
||||||
res = rec.Result()
|
res = rec.Result()
|
||||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||||
assert.True(t, f.done)
|
assert.True(t, f.Done)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
|
func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
|
||||||
apiConf := &Conf{
|
apiConf := &conf.Conf{
|
||||||
Domains: &fakeDomains{},
|
Domains: &fake.Domains{},
|
||||||
Acme: utils.NewAcmeChallenge(),
|
Acme: utils.NewAcmeChallenge(),
|
||||||
Signer: snakeOilProv,
|
Signer: fake.SnakeOilProv,
|
||||||
}
|
}
|
||||||
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
||||||
acmeKey := genSnakeOilKey("violet:acme-challenge")
|
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
|
||||||
|
|
||||||
// Valid domain
|
// Valid domain
|
||||||
req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil)
|
req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil)
|
||||||
@ -119,13 +82,13 @@ func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
|
func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
|
||||||
apiConf := &Conf{
|
apiConf := &conf.Conf{
|
||||||
Domains: &fakeDomains{},
|
Domains: &fake.Domains{},
|
||||||
Acme: utils.NewAcmeChallenge(),
|
Acme: utils.NewAcmeChallenge(),
|
||||||
Signer: snakeOilProv,
|
Signer: fake.SnakeOilProv,
|
||||||
}
|
}
|
||||||
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
||||||
acmeKey := genSnakeOilKey("violet:acme-challenge")
|
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
|
||||||
|
|
||||||
// Valid domain
|
// Valid domain
|
||||||
req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil)
|
req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil)
|
49
servers/api/auth.go
Normal file
49
servers/api/auth.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/auth"
|
||||||
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims]
|
||||||
|
|
||||||
|
type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims)
|
||||||
|
|
||||||
|
// checkAuth validates the bearer token against a mjwt.Verifier and returns an
|
||||||
|
// error message or continues to the next handler
|
||||||
|
func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
// Get bearer token
|
||||||
|
bearer := utils.GetBearer(req)
|
||||||
|
if bearer == "" {
|
||||||
|
apiError(rw, http.StatusForbidden, "Missing bearer token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read claims from mjwt
|
||||||
|
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusForbidden, "Invalid token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(rw, req, params, AuthClaims(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAuthWithPerm validates the bearer token and checks if it contains a
|
||||||
|
// required permission and returns an error message or continues to the next
|
||||||
|
// handler
|
||||||
|
func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle {
|
||||||
|
return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
// check perms
|
||||||
|
if !b.Claims.Perms.Has(perm) {
|
||||||
|
apiError(rw, http.StatusForbidden, "No permission")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cb(rw, req, params, b)
|
||||||
|
})
|
||||||
|
}
|
27
servers/api/target-types.go
Normal file
27
servers/api/target-types.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/MrMelon54/violet/target"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sourceJson struct {
|
||||||
|
Src string `json:"src"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sourceJson) GetSource() string { return s.Src }
|
||||||
|
|
||||||
|
type routeSource target.Route
|
||||||
|
|
||||||
|
func (r routeSource) GetSource() string { return r.Src }
|
||||||
|
|
||||||
|
type redirectSource target.Redirect
|
||||||
|
|
||||||
|
func (r redirectSource) GetSource() string { return r.Src }
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ sourceGetter = sourceJson{}
|
||||||
|
_ sourceGetter = routeSource{}
|
||||||
|
_ sourceGetter = redirectSource{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type sourceGetter interface{ GetSource() string }
|
88
servers/api/target.go
Normal file
88
servers/api/target.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/violet/router"
|
||||||
|
"github.com/MrMelon54/violet/target"
|
||||||
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TargetApis struct {
|
||||||
|
CreateRoute httprouter.Handle
|
||||||
|
DeleteRoute httprouter.Handle
|
||||||
|
CreateRedirect httprouter.Handle
|
||||||
|
DeleteRedirect httprouter.Handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupTargetApis(verify mjwt.Verifier, manager *router.Manager) *TargetApis {
|
||||||
|
r := &TargetApis{
|
||||||
|
CreateRoute: parseJsonAndCheckOwnership[routeSource](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t routeSource) {
|
||||||
|
err := manager.InsertRoute(target.Route(t))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Violet] Failed to insert route into database: %s\n", err)
|
||||||
|
apiError(rw, http.StatusInternalServerError, "Failed to insert route into database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
manager.Compile()
|
||||||
|
}),
|
||||||
|
DeleteRoute: parseJsonAndCheckOwnership[sourceJson](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
|
||||||
|
err := manager.DeleteRoute(t.Src)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Violet] Failed to delete route from database: %s\n", err)
|
||||||
|
apiError(rw, http.StatusInternalServerError, "Failed to delete route from database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
manager.Compile()
|
||||||
|
}),
|
||||||
|
CreateRedirect: parseJsonAndCheckOwnership[redirectSource](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t redirectSource) {
|
||||||
|
err := manager.InsertRedirect(target.Redirect(t))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Violet] Failed to insert redirect into database: %s\n", err)
|
||||||
|
apiError(rw, http.StatusInternalServerError, "Failed to insert redirect into database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
manager.Compile()
|
||||||
|
}),
|
||||||
|
DeleteRedirect: parseJsonAndCheckOwnership[sourceJson](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
|
||||||
|
err := manager.DeleteRedirect(t.Src)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Violet] Failed to delete redirect from database: %s\n", err)
|
||||||
|
apiError(rw, http.StatusInternalServerError, "Failed to delete redirect from database")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
manager.Compile()
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthWithJsonCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t T)
|
||||||
|
|
||||||
|
func parseJsonAndCheckOwnership[T sourceGetter](verify mjwt.Verifier, t string, cb AuthWithJsonCallback[T]) httprouter.Handle {
|
||||||
|
return checkAuthWithPerm(verify, "violet:"+t, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
var j T
|
||||||
|
if json.NewDecoder(req.Body).Decode(&j) != nil {
|
||||||
|
apiError(rw, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// check token owns this domain
|
||||||
|
host, _ := utils.SplitHostPath(j.GetSource())
|
||||||
|
if strings.IndexByte(host, ':') != -1 {
|
||||||
|
apiError(rw, http.StatusBadRequest, "Invalid route source")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validateDomainOwnershipClaims(host, b.Claims.Perms) {
|
||||||
|
apiError(rw, http.StatusBadRequest, "Token cannot modify the specified domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(rw, req, params, b, j)
|
||||||
|
})
|
||||||
|
}
|
@ -1,12 +1,12 @@
|
|||||||
package servers
|
package conf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||||
"github.com/MrMelon54/violet/favicons"
|
"github.com/MrMelon54/violet/favicons"
|
||||||
"github.com/MrMelon54/violet/router"
|
"github.com/MrMelon54/violet/router"
|
||||||
|
"github.com/MrMelon54/violet/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
|
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
|
||||||
@ -16,29 +16,11 @@ type Conf struct {
|
|||||||
HttpsListen string // https server listen address
|
HttpsListen string // https server listen address
|
||||||
RateLimit uint64 // rate limit per minute
|
RateLimit uint64 // rate limit per minute
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Domains DomainProvider
|
Domains utils.DomainProvider
|
||||||
Acme AcmeChallengeProvider
|
Acme utils.AcmeChallengeProvider
|
||||||
Certs CertProvider
|
Certs utils.CertProvider
|
||||||
Favicons *favicons.Favicons
|
Favicons *favicons.Favicons
|
||||||
Signer mjwt.Verifier
|
Signer mjwt.Verifier
|
||||||
ErrorPages *errorPages.ErrorPages
|
ErrorPages *errorPages.ErrorPages
|
||||||
Router *router.Manager
|
Router *router.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
type DomainProvider interface {
|
|
||||||
IsValid(host string) bool
|
|
||||||
Put(domain string, active bool)
|
|
||||||
Delete(domain string)
|
|
||||||
Compile()
|
|
||||||
}
|
|
||||||
|
|
||||||
type AcmeChallengeProvider interface {
|
|
||||||
Get(domain, key string) string
|
|
||||||
Put(domain, key, value string)
|
|
||||||
Delete(domain, key string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type CertProvider interface {
|
|
||||||
GetCertForDomain(domain string) *tls.Certificate
|
|
||||||
Compile()
|
|
||||||
}
|
|
@ -2,6 +2,7 @@ package servers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -14,7 +15,7 @@ import (
|
|||||||
//
|
//
|
||||||
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||||
// acme challenges, this is used for Let's Encrypt HTTP verification.
|
// acme challenges, this is used for Let's Encrypt HTTP verification.
|
||||||
func NewHttpServer(conf *Conf) *http.Server {
|
func NewHttpServer(conf *conf.Conf) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
var secureExtend string
|
var secureExtend string
|
||||||
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
|
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
|
||||||
|
@ -2,7 +2,9 @@ package servers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/MrMelon54/violet/utils/fake"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -11,10 +13,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNewHttpServer_AcmeChallenge(t *testing.T) {
|
func TestNewHttpServer_AcmeChallenge(t *testing.T) {
|
||||||
httpConf := &Conf{
|
httpConf := &conf.Conf{
|
||||||
Domains: &fakeDomains{},
|
Domains: &fake.Domains{},
|
||||||
Acme: utils.NewAcmeChallenge(),
|
Acme: utils.NewAcmeChallenge(),
|
||||||
Signer: snakeOilProv,
|
Signer: fake.SnakeOilProv,
|
||||||
}
|
}
|
||||||
srv := NewHttpServer(httpConf)
|
srv := NewHttpServer(httpConf)
|
||||||
httpConf.Acme.Put("example.com", "456", "456def")
|
httpConf.Acme.Put("example.com", "456", "456def")
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/MrMelon54/violet/favicons"
|
"github.com/MrMelon54/violet/favicons"
|
||||||
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
"github.com/MrMelon54/violet/utils"
|
"github.com/MrMelon54/violet/utils"
|
||||||
"github.com/sethvargo/go-limiter/httplimit"
|
"github.com/sethvargo/go-limiter/httplimit"
|
||||||
"github.com/sethvargo/go-limiter/memorystore"
|
"github.com/sethvargo/go-limiter/memorystore"
|
||||||
@ -16,7 +17,7 @@ import (
|
|||||||
|
|
||||||
// NewHttpsServer creates and runs a http server containing the public https
|
// NewHttpsServer creates and runs a http server containing the public https
|
||||||
// endpoints for the reverse proxy.
|
// endpoints for the reverse proxy.
|
||||||
func NewHttpsServer(conf *Conf) *http.Server {
|
func NewHttpsServer(conf *conf.Conf) *http.Server {
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
Addr: conf.HttpsListen,
|
Addr: conf.HttpsListen,
|
||||||
Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"github.com/MrMelon54/violet/certs"
|
"github.com/MrMelon54/violet/certs"
|
||||||
"github.com/MrMelon54/violet/proxy"
|
"github.com/MrMelon54/violet/proxy"
|
||||||
"github.com/MrMelon54/violet/router"
|
"github.com/MrMelon54/violet/router"
|
||||||
|
"github.com/MrMelon54/violet/servers/conf"
|
||||||
|
"github.com/MrMelon54/violet/utils/fake"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -26,11 +28,11 @@ func TestNewHttpsServer_RateLimit(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ft := &fakeTransport{}
|
ft := &fakeTransport{}
|
||||||
httpsConf := &Conf{
|
httpsConf := &conf.Conf{
|
||||||
RateLimit: 5,
|
RateLimit: 5,
|
||||||
Domains: &fakeDomains{},
|
Domains: &fake.Domains{},
|
||||||
Certs: certs.New(nil, nil, true),
|
Certs: certs.New(nil, nil, true),
|
||||||
Signer: snakeOilProv,
|
Signer: fake.SnakeOilProv,
|
||||||
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
|
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
|
||||||
}
|
}
|
||||||
srv := NewHttpsServer(httpsConf)
|
srv := NewHttpsServer(httpsConf)
|
||||||
|
41
target/flags.go
Normal file
41
target/flags.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package target
|
||||||
|
|
||||||
|
type Flags uint64
|
||||||
|
|
||||||
|
const (
|
||||||
|
FlagPre Flags = 1 << iota
|
||||||
|
FlagAbs
|
||||||
|
FlagCors
|
||||||
|
FlagSecureMode
|
||||||
|
FlagForwardHost
|
||||||
|
FlagForwardAddr
|
||||||
|
FlagIgnoreCert
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
routeFlagMask = FlagPre | FlagAbs | FlagCors | FlagSecureMode | FlagForwardHost | FlagForwardAddr | FlagIgnoreCert
|
||||||
|
redirectFlagMask = FlagPre | FlagAbs
|
||||||
|
)
|
||||||
|
|
||||||
|
// HasFlag returns true if the bits contain the requested flag
|
||||||
|
func (f Flags) HasFlag(flag Flags) bool {
|
||||||
|
// 0110 & 0100 == 0100 (value != 0 thus true)
|
||||||
|
// 0011 & 0100 == 0000 (value == 0 thus false)
|
||||||
|
return f&flag != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormaliseRouteFlags returns only the bits used for routes
|
||||||
|
func (f Flags) NormaliseRouteFlags() Flags {
|
||||||
|
// removes bits outside the mask
|
||||||
|
// 0110 & 0111 == 0110
|
||||||
|
// 1010 & 0111 == 0010 (values are different)
|
||||||
|
return f & routeFlagMask
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormaliseRedirectFlags returns only the bits used for redirects
|
||||||
|
func (f Flags) NormaliseRedirectFlags() Flags {
|
||||||
|
// removes bits outside the mask
|
||||||
|
// 0110 & 0111 == 0110
|
||||||
|
// 1010 & 0111 == 0010 (values are different)
|
||||||
|
return f & redirectFlagMask
|
||||||
|
}
|
@ -12,20 +12,14 @@ import (
|
|||||||
// Redirect is a target used by the router to manage redirecting the request
|
// Redirect is a target used by the router to manage redirecting the request
|
||||||
// using the specified configuration.
|
// using the specified configuration.
|
||||||
type Redirect struct {
|
type Redirect struct {
|
||||||
Pre bool // if the path has had a prefix removed
|
Src string `json:"src"` // request source
|
||||||
Host string // target host
|
Dst string `json:"dst"` // redirect destination
|
||||||
Port int // target port
|
Flags Flags `json:"flags"` // extra flags
|
||||||
Path string // target path (possibly a prefix or absolute)
|
Code int `json:"code"` // status code used to redirect
|
||||||
Abs bool // if the path is a prefix or absolute
|
|
||||||
Code int // status code used to redirect
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FullHost outputs a host:port combo or just the host if the port is 0.
|
func (r Route) HasFlag(flag Flags) bool {
|
||||||
func (r Redirect) FullHost() string {
|
return r.Flags&flag != 0
|
||||||
if r.Port == 0 {
|
|
||||||
return r.Host
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP responds with the redirect to the response writer provided.
|
// ServeHTTP responds with the redirect to the response writer provided.
|
||||||
@ -36,10 +30,12 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
code = http.StatusFound
|
code = http.StatusFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// split the host and path
|
||||||
|
host, p := utils.SplitHostPath(r.Dst)
|
||||||
|
|
||||||
// if not Abs then join with the ending of the current path
|
// if not Abs then join with the ending of the current path
|
||||||
p := r.Path
|
if !r.Flags.HasFlag(FlagAbs) {
|
||||||
if !r.Abs {
|
p = path.Join(p, req.URL.Path)
|
||||||
p = path.Join(r.Path, req.URL.Path)
|
|
||||||
|
|
||||||
// replace the trailing slash that path.Join() strips off
|
// replace the trailing slash that path.Join() strips off
|
||||||
if strings.HasSuffix(req.URL.Path, "/") {
|
if strings.HasSuffix(req.URL.Path, "/") {
|
||||||
@ -55,7 +51,7 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
// create a new URL
|
// create a new URL
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: req.URL.Scheme,
|
Scheme: req.URL.Scheme,
|
||||||
Host: r.FullHost(),
|
Host: host,
|
||||||
Path: p,
|
Path: p,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,18 +7,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRedirect_FullHost(t *testing.T) {
|
|
||||||
assert.Equal(t, "localhost", Redirect{Host: "localhost"}.FullHost())
|
|
||||||
assert.Equal(t, "localhost:22", Redirect{Host: "localhost", Port: 22}.FullHost())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRedirect_ServeHTTP(t *testing.T) {
|
func TestRedirect_ServeHTTP(t *testing.T) {
|
||||||
a := []struct {
|
a := []struct {
|
||||||
Redirect
|
Redirect
|
||||||
target string
|
target string
|
||||||
}{
|
}{
|
||||||
{Redirect{Host: "example.com", Path: "/bye", Abs: true, Code: http.StatusFound}, "https://example.com/bye"},
|
{Redirect{Dst: "example.com/bye", Flags: FlagAbs, Code: http.StatusFound}, "https://example.com/bye"},
|
||||||
{Redirect{Host: "example.com", Path: "/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"},
|
{Redirect{Dst: "example.com/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"},
|
||||||
}
|
}
|
||||||
for _, i := range a {
|
for _, i := range a {
|
||||||
res := httptest.NewRecorder()
|
res := httptest.NewRecorder()
|
||||||
|
@ -36,18 +36,11 @@ var serveApiCors = cors.New(cors.Options{
|
|||||||
// Route is a target used by the router to manage forwarding traffic to an
|
// Route is a target used by the router to manage forwarding traffic to an
|
||||||
// internal server using the specified configuration.
|
// internal server using the specified configuration.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Pre bool // if the path has had a prefix removed
|
Src string `json:"src"` // request source
|
||||||
Host string // target host
|
Dst string `json:"dst"` // proxy destination
|
||||||
Port int // target port
|
Flags Flags `json:"flags"` // extra flags
|
||||||
Path string // target path (possibly a prefix or absolute)
|
Headers http.Header `json:"-"` // extra headers
|
||||||
Abs bool // if the path is a prefix or absolute
|
Proxy *proxy.HybridTransport `json:"-"` // reverse proxy handler
|
||||||
Cors bool // add CORS headers
|
|
||||||
SecureMode bool // use HTTPS internally
|
|
||||||
ForwardHost bool // forward host header internally
|
|
||||||
ForwardAddr bool // forward remote address
|
|
||||||
IgnoreCert bool // ignore self-cert
|
|
||||||
Headers http.Header // extra headers
|
|
||||||
Proxy *proxy.HybridTransport // reverse proxy handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||||
@ -58,18 +51,10 @@ func (r Route) UpdateHeaders(header http.Header) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FullHost outputs a host:port combo or just the host if the port is 0.
|
|
||||||
func (r Route) FullHost() string {
|
|
||||||
if r.Port == 0 {
|
|
||||||
return r.Host
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP responds with the data proxied from the internal server to the
|
// ServeHTTP responds with the data proxied from the internal server to the
|
||||||
// response writer provided.
|
// response writer provided.
|
||||||
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
if r.Cors {
|
if r.HasFlag(FlagCors) {
|
||||||
// wraps with CORS handler
|
// wraps with CORS handler
|
||||||
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
|
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
|
||||||
} else {
|
} else {
|
||||||
@ -82,21 +67,16 @@ func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
// set the scheme and port using defaults if the port is 0
|
// set the scheme and port using defaults if the port is 0
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
if r.SecureMode {
|
if r.HasFlag(FlagSecureMode) {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
if r.Port == 0 {
|
|
||||||
r.Port = 443
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if r.Port == 0 {
|
|
||||||
r.Port = 80
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// split the host and path
|
||||||
|
host, p := utils.SplitHostPath(r.Dst)
|
||||||
|
|
||||||
// if not Abs then join with the ending of the current path
|
// if not Abs then join with the ending of the current path
|
||||||
p := r.Path
|
if !r.HasFlag(FlagAbs) {
|
||||||
if !r.Abs {
|
p = path.Join(p, req.URL.Path)
|
||||||
p = path.Join(r.Path, req.URL.Path)
|
|
||||||
|
|
||||||
// replace the trailing slash that path.Join() strips off
|
// replace the trailing slash that path.Join() strips off
|
||||||
if strings.HasSuffix(req.URL.Path, "/") {
|
if strings.HasSuffix(req.URL.Path, "/") {
|
||||||
@ -112,7 +92,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
// create a new URL
|
// create a new URL
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: scheme,
|
Scheme: scheme,
|
||||||
Host: r.FullHost(),
|
Host: host,
|
||||||
Path: p,
|
Path: p,
|
||||||
RawQuery: req.URL.RawQuery,
|
RawQuery: req.URL.RawQuery,
|
||||||
}
|
}
|
||||||
@ -150,10 +130,10 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if forward host is enabled then send the host
|
// if forward host is enabled then send the host
|
||||||
if r.ForwardHost {
|
if r.HasFlag(FlagForwardHost) {
|
||||||
req2.Host = req.Host
|
req2.Host = req.Host
|
||||||
}
|
}
|
||||||
if r.ForwardAddr {
|
if r.HasFlag(FlagForwardAddr) {
|
||||||
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +142,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
// serve request with reverse proxy
|
// serve request with reverse proxy
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
if r.IgnoreCert {
|
if r.HasFlag(FlagIgnoreCert) {
|
||||||
resp, err = r.Proxy.InsecureRoundTrip(req2)
|
resp, err = r.Proxy.InsecureRoundTrip(req2)
|
||||||
} else {
|
} else {
|
||||||
resp, err = r.Proxy.SecureRoundTrip(req2)
|
resp, err = r.Proxy.SecureRoundTrip(req2)
|
||||||
|
@ -25,9 +25,9 @@ func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoute_FullHost(t *testing.T) {
|
func TestRoute_HasFlag(t *testing.T) {
|
||||||
assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost())
|
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
|
||||||
assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost())
|
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoute_ServeHTTP(t *testing.T) {
|
func TestRoute_ServeHTTP(t *testing.T) {
|
||||||
@ -35,12 +35,12 @@ func TestRoute_ServeHTTP(t *testing.T) {
|
|||||||
Route
|
Route
|
||||||
target string
|
target string
|
||||||
}{
|
}{
|
||||||
{Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"},
|
{Route{Dst: "localhost:1234/bye", Flags: FlagAbs}, "http://localhost:1234/bye"},
|
||||||
{Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"},
|
{Route{Dst: "1.2.3.4/bye"}, "http://1.2.3.4/bye/hello/world"},
|
||||||
{Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"},
|
{Route{Dst: "2.2.2.2/world", Flags: FlagAbs | FlagSecureMode}, "https://2.2.2.2/world"},
|
||||||
{Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"},
|
{Route{Dst: "api.example.com/world", Flags: FlagAbs | FlagSecureMode | FlagForwardHost}, "https://api.example.com/world"},
|
||||||
{Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"},
|
{Route{Dst: "api.example.org/world", Flags: FlagAbs | FlagSecureMode | FlagForwardAddr}, "https://api.example.org/world"},
|
||||||
{Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"},
|
{Route{Dst: "3.3.3.3/headers", Flags: FlagAbs, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3/headers"},
|
||||||
}
|
}
|
||||||
for _, i := range a {
|
for _, i := range a {
|
||||||
pt := &proxyTester{}
|
pt := &proxyTester{}
|
||||||
@ -51,10 +51,10 @@ func TestRoute_ServeHTTP(t *testing.T) {
|
|||||||
|
|
||||||
assert.True(t, pt.got)
|
assert.True(t, pt.got)
|
||||||
assert.Equal(t, i.target, pt.req.URL.String())
|
assert.Equal(t, i.target, pt.req.URL.String())
|
||||||
if i.ForwardAddr {
|
if i.HasFlag(FlagForwardAddr) {
|
||||||
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
|
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
|
||||||
}
|
}
|
||||||
if i.ForwardHost {
|
if i.HasFlag(FlagForwardHost) {
|
||||||
assert.Equal(t, req.Host, pt.req.Host)
|
assert.Equal(t, req.Host, pt.req.Host)
|
||||||
}
|
}
|
||||||
if i.Headers != nil {
|
if i.Headers != nil {
|
||||||
@ -68,7 +68,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
|
|||||||
res := httptest.NewRecorder()
|
res := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
||||||
req.Header.Set("Origin", "https://test.example.com")
|
req.Header.Set("Origin", "https://test.example.com")
|
||||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
|
i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()}
|
||||||
i.ServeHTTP(res, req)
|
i.ServeHTTP(res, req)
|
||||||
|
|
||||||
assert.True(t, pt.got)
|
assert.True(t, pt.got)
|
||||||
@ -86,7 +86,7 @@ func TestRoute_ServeHTTP_Body(t *testing.T) {
|
|||||||
buf := bytes.NewBuffer([]byte{0x54})
|
buf := bytes.NewBuffer([]byte{0x54})
|
||||||
req := httptest.NewRequest(http.MethodPost, "https://www.example.com/test", buf)
|
req := httptest.NewRequest(http.MethodPost, "https://www.example.com/test", buf)
|
||||||
req.Header.Set("Origin", "https://test.example.com")
|
req.Header.Set("Origin", "https://test.example.com")
|
||||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
|
i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()}
|
||||||
i.ServeHTTP(res, req)
|
i.ServeHTTP(res, req)
|
||||||
|
|
||||||
assert.True(t, pt.got)
|
assert.True(t, pt.got)
|
||||||
|
@ -83,3 +83,38 @@ func GetTopFqdn(domain string) (string, bool) {
|
|||||||
}
|
}
|
||||||
return domain[n+1:], true
|
return domain[n+1:], true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SplitHostPath extracts the host/path from the input
|
||||||
|
func SplitHostPath(a string) (host, path string) {
|
||||||
|
// check if source has path
|
||||||
|
n := strings.IndexByte(a, '/')
|
||||||
|
if n == -1 {
|
||||||
|
// set host then path to /
|
||||||
|
host = a
|
||||||
|
path = "/"
|
||||||
|
} else {
|
||||||
|
// set host then custom path
|
||||||
|
host = a[:n]
|
||||||
|
path = a[n:] // this required to keep / at the start of the path
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitHostPathQuery extracts the host/path?query from the input
|
||||||
|
func SplitHostPathQuery(a string) (host, path, query string) {
|
||||||
|
host, path = SplitHostPath(a)
|
||||||
|
if path == "/" {
|
||||||
|
n := strings.IndexByte(host, '?')
|
||||||
|
if n != -1 {
|
||||||
|
query = host[n+1:]
|
||||||
|
host = host[:n]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n := strings.IndexByte(path, '?')
|
||||||
|
if n != -1 {
|
||||||
|
query = path[n+1:]
|
||||||
|
path = path[:n] // reassign happens after
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -60,3 +60,40 @@ func TestGetTopFqdn(t *testing.T) {
|
|||||||
assert.True(t, ok, "Output should be true")
|
assert.True(t, ok, "Output should be true")
|
||||||
assert.Equal(t, "example.com", domain)
|
assert.Equal(t, "example.com", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSplitHostPath(t *testing.T) {
|
||||||
|
h, p := SplitHostPath("example.com/hello/world")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/hello/world", p)
|
||||||
|
|
||||||
|
h, p = SplitHostPath("example.com")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitHostPathQuery(t *testing.T) {
|
||||||
|
h, p, q := SplitHostPathQuery("example.com/hello/world")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/hello/world", p)
|
||||||
|
assert.Equal(t, "", q)
|
||||||
|
|
||||||
|
h, p, q = SplitHostPathQuery("example.com")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/", p)
|
||||||
|
assert.Equal(t, "", q)
|
||||||
|
|
||||||
|
h, p, q = SplitHostPathQuery("example.com/hello/world?a=b")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/hello/world", p)
|
||||||
|
assert.Equal(t, "a=b", q)
|
||||||
|
|
||||||
|
h, p, q = SplitHostPathQuery("example.com?a=b")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/", p)
|
||||||
|
assert.Equal(t, "a=b", q)
|
||||||
|
|
||||||
|
h, p, q = SplitHostPathQuery("example.com/?a=b")
|
||||||
|
assert.Equal(t, "example.com", h)
|
||||||
|
assert.Equal(t, "/", p)
|
||||||
|
assert.Equal(t, "a=b", q)
|
||||||
|
}
|
||||||
|
11
utils/fake/fake-compilable.go
Normal file
11
utils/fake/fake-compilable.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package fake
|
||||||
|
|
||||||
|
import "github.com/MrMelon54/violet/utils"
|
||||||
|
|
||||||
|
// Compilable implements utils.Compilable and stores if the Compile function
|
||||||
|
// is called.
|
||||||
|
type Compilable struct{ Done bool }
|
||||||
|
|
||||||
|
func (f *Compilable) Compile() { f.Done = true }
|
||||||
|
|
||||||
|
var _ utils.Compilable = &Compilable{}
|
13
utils/fake/fake-domains.go
Normal file
13
utils/fake/fake-domains.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package fake
|
||||||
|
|
||||||
|
import "github.com/MrMelon54/violet/utils"
|
||||||
|
|
||||||
|
// Domains implements DomainProvider and makes sure `example.com` is valid
|
||||||
|
type Domains struct{}
|
||||||
|
|
||||||
|
func (f *Domains) IsValid(host string) bool { return host == "example.com" }
|
||||||
|
func (f *Domains) Put(string, bool) {}
|
||||||
|
func (f *Domains) Delete(string) {}
|
||||||
|
func (f *Domains) Compile() {}
|
||||||
|
|
||||||
|
var _ utils.DomainProvider = &Domains{}
|
2
utils/fake/fake.go
Normal file
2
utils/fake/fake.go
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
// Package fake contains fake structs used during tests
|
||||||
|
package fake
|
30
utils/fake/mjwt.go
Normal file
30
utils/fake/mjwt.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package fake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/auth"
|
||||||
|
"github.com/MrMelon54/mjwt/claims"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SnakeOilProv = GenSnakeOilProv()
|
||||||
|
|
||||||
|
func GenSnakeOilProv() mjwt.Signer {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return mjwt.NewMJwtSigner("violet.test", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenSnakeOilKey(perm string) string {
|
||||||
|
p := claims.NewPermStorage()
|
||||||
|
p.Set(perm)
|
||||||
|
val, err := SnakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
21
utils/interfaces.go
Normal file
21
utils/interfaces.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
type DomainProvider interface {
|
||||||
|
IsValid(host string) bool
|
||||||
|
Put(domain string, active bool)
|
||||||
|
Delete(domain string)
|
||||||
|
Compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
type AcmeChallengeProvider interface {
|
||||||
|
Get(domain, key string) string
|
||||||
|
Put(domain, key, value string)
|
||||||
|
Delete(domain, key string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertProvider interface {
|
||||||
|
GetCertForDomain(domain string) *tls.Certificate
|
||||||
|
Compile()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user