mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 19:01:39 +00:00
Write route/redirect APIs and rearrage some other code to make it possible
This commit is contained in:
parent
c930ddff28
commit
949dcd298a
@ -1,11 +1,11 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
||||
<data-source source="LOCAL" name="__db.sqlite" uuid="5aeb4e88-8ec4-4227-a921-ba4eaed357bf">
|
||||
<data-source source="LOCAL" name="identifier.sqlite" uuid="5b42d21a-92a8-43d0-8651-c1555b91060c">
|
||||
<driver-ref>sqlite.xerial</driver-ref>
|
||||
<synchronize>true</synchronize>
|
||||
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
||||
<jdbc-url>jdbc:sqlite:__db.sqlite</jdbc-url>
|
||||
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
|
||||
<working-dir>$ProjectFileDir$</working-dir>
|
||||
</data-source>
|
||||
</component>
|
||||
|
@ -14,6 +14,8 @@ import (
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/servers"
|
||||
"github.com/MrMelon54/violet/servers/api"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/google/subcommands"
|
||||
"io/fs"
|
||||
@ -70,9 +72,9 @@ func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
func normalLoad(conf startUpConfig, wd string) {
|
||||
func normalLoad(startUp startUpConfig, wd string) {
|
||||
// the cert and key paths are useless in self-signed mode
|
||||
if !conf.SelfSigned {
|
||||
if !startUp.SelfSigned {
|
||||
// create path to cert dir
|
||||
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
|
||||
if err != nil {
|
||||
@ -87,11 +89,11 @@ func normalLoad(conf startUpConfig, wd string) {
|
||||
|
||||
// errorPageDir stores an FS interface for accessing the error page directory
|
||||
var errorPageDir fs.FS
|
||||
if conf.ErrorPagePath != "" {
|
||||
errorPageDir = os.DirFS(conf.ErrorPagePath)
|
||||
err := os.MkdirAll(conf.ErrorPagePath, os.ModePerm)
|
||||
if startUp.ErrorPagePath != "" {
|
||||
errorPageDir = os.DirFS(startUp.ErrorPagePath)
|
||||
err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create error page path '%s'", conf.ErrorPagePath)
|
||||
log.Fatalf("[Violet] Failed to create error page path '%s'", startUp.ErrorPagePath)
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,18 +114,18 @@ func normalLoad(conf startUpConfig, wd string) {
|
||||
|
||||
allowedDomains := domains.New(db) // load allowed domains
|
||||
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
|
||||
allowedCerts := certs.New(certDir, keyDir, conf.SelfSigned) // load certificate manager
|
||||
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
|
||||
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
|
||||
dynamicFavicons := favicons.New(db, conf.InkscapeCmd) // load dynamic favicon provider
|
||||
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
|
||||
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
|
||||
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
|
||||
|
||||
// struct containing config for the http servers
|
||||
srvConf := &servers.Conf{
|
||||
ApiListen: conf.Listen.Api,
|
||||
HttpListen: conf.Listen.Http,
|
||||
HttpsListen: conf.Listen.Https,
|
||||
RateLimit: conf.RateLimit,
|
||||
srvConf := &conf.Conf{
|
||||
ApiListen: startUp.Listen.Api,
|
||||
HttpListen: startUp.Listen.Http,
|
||||
HttpsListen: startUp.Listen.Https,
|
||||
RateLimit: startUp.RateLimit,
|
||||
DB: db,
|
||||
Domains: allowedDomains,
|
||||
Acme: acmeChallenges,
|
||||
@ -140,7 +142,7 @@ func normalLoad(conf startUpConfig, wd string) {
|
||||
|
||||
var srvApi, srvHttp, srvHttps *http.Server
|
||||
if srvConf.ApiListen != "" {
|
||||
srvApi = servers.NewApiServer(srvConf, allCompilables)
|
||||
srvApi = api.NewApiServer(srvConf, allCompilables)
|
||||
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
|
||||
go utils.RunBackgroundHttp("API", srvApi)
|
||||
}
|
||||
|
@ -181,13 +181,15 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
|
||||
// add with the route manager, no need to compile as this will run when opened
|
||||
// with the serve subcommand
|
||||
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}))
|
||||
routeManager.Add(path.Join(apiUrl.Host, apiUrl.Path), target.Route{
|
||||
Pre: true,
|
||||
Host: answers.ApiListen,
|
||||
Cors: true,
|
||||
ForwardHost: true,
|
||||
ForwardAddr: true,
|
||||
}, true)
|
||||
err = routeManager.InsertRoute(target.Route{
|
||||
Src: path.Join(apiUrl.Host, apiUrl.Path),
|
||||
Dst: answers.ApiListen,
|
||||
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Println("[Violet] Failed to insert api route into database: ", err)
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("[Violet] Setup complete")
|
||||
|
7
favicons/create-table-favicons.sql
Normal file
7
favicons/create-table-favicons.sql
Normal file
@ -0,0 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS favicons (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
host VARCHAR,
|
||||
svg VARCHAR,
|
||||
png VARCHAR,
|
||||
ico VARCHAR
|
||||
);
|
@ -2,6 +2,7 @@ package favicons
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/rescheduler"
|
||||
@ -12,6 +13,9 @@ import (
|
||||
|
||||
var ErrFaviconNotFound = errors.New("favicon not found")
|
||||
|
||||
//go:embed create-table-favicons.sql
|
||||
var createTableFavicons string
|
||||
|
||||
// Favicons is a dynamic favicon generator which supports overwriting favicons
|
||||
type Favicons struct {
|
||||
db *sql.DB
|
||||
@ -32,7 +36,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
|
||||
f.r = rescheduler.NewRescheduler(f.threadCompile)
|
||||
|
||||
// init favicons table
|
||||
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
|
||||
_, err := f.db.Exec(createTableFavicons)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'favicons' table\n")
|
||||
return nil
|
||||
|
@ -1,10 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS redirects
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT,
|
||||
pre INTEGER,
|
||||
destination TEXT,
|
||||
abs INTEGER,
|
||||
code INTEGER,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
@ -1,14 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS routes
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT,
|
||||
pre INTEGER,
|
||||
destination TEXT,
|
||||
abs INTEGER,
|
||||
cors INTEGER,
|
||||
secure_mode INTEGER,
|
||||
forward_host INTEGER,
|
||||
forward_addr INTEGER,
|
||||
ignore_cert INTEGER,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
18
router/create-tables.sql
Normal file
18
router/create-tables.sql
Normal file
@ -0,0 +1,18 @@
|
||||
CREATE TABLE IF NOT EXISTS routes
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT UNIQUE,
|
||||
destination TEXT,
|
||||
flags INTEGER DEFAULT 0,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS redirects
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT UNIQUE,
|
||||
destination TEXT,
|
||||
flags INTEGER DEFAULT 0,
|
||||
code INTEGER DEFAULT 0,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
@ -3,15 +3,11 @@ package router
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/rescheduler"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"log"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -26,14 +22,8 @@ type Manager struct {
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed create-table-routes.sql
|
||||
createTableRoutes string
|
||||
//go:embed create-table-redirects.sql
|
||||
createTableRedirects string
|
||||
//go:embed query-table-routes.sql
|
||||
queryTableRoutes string
|
||||
//go:embed query-table-redirects.sql
|
||||
queryTableRedirects string
|
||||
//go:embed create-tables.sql
|
||||
createTables string
|
||||
)
|
||||
|
||||
// NewManager create a new manager, initialises the routes and redirects tables
|
||||
@ -48,16 +38,9 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
|
||||
m.z = rescheduler.NewRescheduler(m.threadCompile)
|
||||
|
||||
// init routes table
|
||||
_, err := m.db.Exec(createTableRoutes)
|
||||
_, err := m.db.Exec(createTables)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'routes' table\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// init redirects table
|
||||
_, err = m.db.Exec(createTableRedirects)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to generate 'redirects' table\n")
|
||||
log.Printf("[WARN] Failed to generate tables\n")
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
@ -96,7 +79,7 @@ func (m *Manager) internalCompile(router *Router) error {
|
||||
log.Println("[Manager] Updating routes from database")
|
||||
|
||||
// sql or something?
|
||||
rows, err := m.db.Query(queryTableRoutes)
|
||||
rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -105,26 +88,19 @@ func (m *Manager) internalCompile(router *Router) error {
|
||||
// loop through rows and scan the options
|
||||
for rows.Next() {
|
||||
var (
|
||||
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
|
||||
src, dst string
|
||||
flags target.Flags
|
||||
)
|
||||
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
|
||||
err := rows.Scan(&src, &dst, &flags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = addRoute(router, src, dst, target.Route{
|
||||
Pre: pre,
|
||||
Abs: abs,
|
||||
Cors: cors,
|
||||
SecureMode: secure_mode,
|
||||
ForwardHost: forward_host,
|
||||
ForwardAddr: forward_addr,
|
||||
IgnoreCert: ignore_cert,
|
||||
router.AddRoute(target.Route{
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Flags: flags.NormaliseRouteFlags(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for errors
|
||||
@ -133,7 +109,7 @@ func (m *Manager) internalCompile(router *Router) error {
|
||||
}
|
||||
|
||||
// sql or something?
|
||||
rows, err = m.db.Query(queryTableRedirects)
|
||||
rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -142,99 +118,51 @@ func (m *Manager) internalCompile(router *Router) error {
|
||||
// loop through rows and scan the options
|
||||
for rows.Next() {
|
||||
var (
|
||||
pre, abs bool
|
||||
code int
|
||||
src, dst string
|
||||
flags target.Flags
|
||||
code int
|
||||
)
|
||||
err := rows.Scan(&src, &pre, &dst, &abs, &code)
|
||||
err := rows.Scan(&src, &dst, &flags, &code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = addRedirect(router, src, dst, target.Redirect{
|
||||
Pre: pre,
|
||||
Abs: abs,
|
||||
router.AddRedirect(target.Redirect{
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Flags: flags.NormaliseRedirectFlags(),
|
||||
Code: code,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for errors
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (m *Manager) Add(source string, route target.Route, active bool) {
|
||||
func (m *Manager) InsertRoute(route target.Route) error {
|
||||
m.s.Lock()
|
||||
defer m.s.Unlock()
|
||||
_, err := m.db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, source, route.Pre, path.Join(route.Host, route.Path), route.Abs, route.Cors, route.SecureMode, route.ForwardHost, route.ForwardAddr, route.IgnoreCert, active)
|
||||
if err != nil {
|
||||
log.Printf("[Violet] Database error: %s\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// addRoute is an alias to parse the src and dst then add the route
|
||||
func addRoute(router *Router, src string, dst string, t target.Route) error {
|
||||
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
||||
if err != nil {
|
||||
_, err := m.db.Exec(`INSERT INTO routes (source, destination, flags) VALUES (?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, active = 1`, route.Src, route.Dst, route.Flags)
|
||||
return err
|
||||
}
|
||||
|
||||
// update target route values and add route
|
||||
t.Host = dstHost
|
||||
t.Port = dstPort
|
||||
t.Path = dstPath
|
||||
router.AddRoute(srcHost, srcPath, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRedirect is an alias to parse the src and dst then add the redirect
|
||||
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
|
||||
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
|
||||
if err != nil {
|
||||
func (m *Manager) DeleteRoute(source string) error {
|
||||
m.s.Lock()
|
||||
defer m.s.Unlock()
|
||||
_, err := m.db.Exec(`UPDATE routes SET active = 0 WHERE source = ?`, source)
|
||||
return err
|
||||
}
|
||||
|
||||
t.Host = dstHost
|
||||
t.Port = dstPort
|
||||
t.Path = dstPath
|
||||
router.AddRedirect(srcHost, srcPath, t)
|
||||
return nil
|
||||
func (m *Manager) InsertRedirect(redirect target.Redirect) error {
|
||||
m.s.Lock()
|
||||
defer m.s.Unlock()
|
||||
_, err := m.db.Exec(`INSERT INTO redirects (source, destination, flags, code) VALUES (?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, code = excluded.code, active = 1`, redirect.Src, redirect.Dst, redirect.Flags, redirect.Code)
|
||||
return err
|
||||
}
|
||||
|
||||
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
|
||||
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
|
||||
// check if source has path
|
||||
var srcHost, srcPath string
|
||||
nSrc := strings.IndexByte(src, '/')
|
||||
if nSrc == -1 {
|
||||
// set host then path to /
|
||||
srcHost = src
|
||||
srcPath = "/"
|
||||
} else {
|
||||
// set host then custom path
|
||||
srcHost = src[:nSrc]
|
||||
srcPath = src[nSrc:]
|
||||
}
|
||||
|
||||
// check if destination has path
|
||||
var dstPath string
|
||||
nDst := strings.IndexByte(dst, '/')
|
||||
if nDst == -1 {
|
||||
// set path to /
|
||||
dstPath = "/"
|
||||
} else {
|
||||
// set custom path then trim dst string to the host
|
||||
dstPath = dst[nDst:]
|
||||
dst = dst[:nDst]
|
||||
}
|
||||
|
||||
// try to split the destination host into domain + port
|
||||
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
|
||||
if !ok {
|
||||
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
|
||||
}
|
||||
|
||||
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
|
||||
func (m *Manager) DeleteRedirect(source string) error {
|
||||
m.s.Lock()
|
||||
defer m.s.Unlock()
|
||||
_, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source)
|
||||
return err
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package router
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
@ -37,7 +38,7 @@ func TestNewManager(t *testing.T) {
|
||||
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||
assert.Nil(t, ft.req)
|
||||
|
||||
_, err = db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?,?,?,?,?,?,?,?,?,?)`, "*.example.com", 0, "127.0.0.1:8080", 1, 0, 0, 1, 1, 0, 1)
|
||||
_, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NoError(t, m.internalCompile(m.r))
|
||||
|
@ -1,7 +0,0 @@
|
||||
select source,
|
||||
pre,
|
||||
destination,
|
||||
abs,
|
||||
code
|
||||
from redirects
|
||||
where active = true
|
@ -1,11 +0,0 @@
|
||||
select source,
|
||||
pre,
|
||||
destination,
|
||||
abs,
|
||||
cors,
|
||||
secure_mode,
|
||||
forward_host,
|
||||
forward_addr,
|
||||
ignore_cert
|
||||
from routes
|
||||
where active = true
|
@ -46,16 +46,14 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] {
|
||||
return h
|
||||
}
|
||||
|
||||
func (r *Router) AddService(host string, t target.Route) {
|
||||
r.AddRoute(host, "/", t)
|
||||
}
|
||||
|
||||
func (r *Router) AddRoute(host string, path string, t target.Route) {
|
||||
func (r *Router) AddRoute(t target.Route) {
|
||||
t.Proxy = r.proxy
|
||||
host, path := utils.SplitHostPath(t.Src)
|
||||
r.hostRoute(host).PutString(path, t)
|
||||
}
|
||||
|
||||
func (r *Router) AddRedirect(host, path string, t target.Redirect) {
|
||||
func (r *Router) AddRedirect(t target.Redirect) {
|
||||
host, path := utils.SplitHostPath(t.Src)
|
||||
r.hostRedirect(host).PutString(path, t)
|
||||
}
|
||||
|
||||
@ -95,7 +93,7 @@ func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host
|
||||
if h != nil {
|
||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path {
|
||||
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||
pairs[i].Value.ServeHTTP(rw, req)
|
||||
return true
|
||||
@ -110,7 +108,7 @@ func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, ho
|
||||
if h != nil {
|
||||
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path {
|
||||
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
|
||||
pairs[i].Value.ServeHTTP(rw, req)
|
||||
return true
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -29,31 +30,31 @@ var (
|
||||
"/": "/",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Route{Path: "/world"}, mss{
|
||||
{"/", target.Route{Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Route{Abs: true}, mss{
|
||||
{"/", target.Route{Flags: target.FlagAbs}, mss{
|
||||
"/": "/",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Route{Abs: true, Path: "world"}, mss{
|
||||
{"/", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Route{Pre: true}, mss{
|
||||
{"/", target.Route{Flags: target.FlagPre}, mss{
|
||||
"/": "/",
|
||||
"/hello": "/hello",
|
||||
}},
|
||||
{"/", target.Route{Pre: true, Path: "world"}, mss{
|
||||
{"/", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "/world/hello",
|
||||
}},
|
||||
{"/", target.Route{Pre: true, Abs: true}, mss{
|
||||
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||
"/": "/",
|
||||
"/hello": "/",
|
||||
}},
|
||||
{"/", target.Route{Pre: true, Abs: true, Path: "world"}, mss{
|
||||
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "/world",
|
||||
}},
|
||||
@ -62,37 +63,37 @@ var (
|
||||
"/hello": "/",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Route{Path: "world"}, mss{
|
||||
{"/hello", target.Route{Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Route{Abs: true}, mss{
|
||||
{"/hello", target.Route{Flags: target.FlagAbs}, mss{
|
||||
"/": "",
|
||||
"/hello": "/",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Route{Abs: true, Path: "world"}, mss{
|
||||
{"/hello", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Route{Pre: true}, mss{
|
||||
{"/hello", target.Route{Flags: target.FlagPre}, mss{
|
||||
"/": "",
|
||||
"/hello": "/",
|
||||
"/hello/hi": "/hi",
|
||||
}},
|
||||
{"/hello", target.Route{Pre: true, Path: "world"}, mss{
|
||||
{"/hello", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "/world/hi",
|
||||
}},
|
||||
{"/hello", target.Route{Pre: true, Abs: true}, mss{
|
||||
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||
"/": "",
|
||||
"/hello": "/",
|
||||
"/hello/hi": "/",
|
||||
}},
|
||||
{"/hello", target.Route{Pre: true, Abs: true, Path: "world"}, mss{
|
||||
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "/world",
|
||||
@ -103,31 +104,31 @@ var (
|
||||
"/": "/",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Redirect{Path: "world"}, mss{
|
||||
{"/", target.Redirect{Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Redirect{Abs: true}, mss{
|
||||
{"/", target.Redirect{Flags: target.FlagAbs}, mss{
|
||||
"/": "/",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Redirect{Abs: true, Path: "world"}, mss{
|
||||
{"/", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "",
|
||||
}},
|
||||
{"/", target.Redirect{Pre: true}, mss{
|
||||
{"/", target.Redirect{Flags: target.FlagPre}, mss{
|
||||
"/": "/",
|
||||
"/hello": "/hello",
|
||||
}},
|
||||
{"/", target.Redirect{Pre: true, Path: "world"}, mss{
|
||||
{"/", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "/world/hello",
|
||||
}},
|
||||
{"/", target.Redirect{Pre: true, Abs: true}, mss{
|
||||
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||
"/": "/",
|
||||
"/hello": "/",
|
||||
}},
|
||||
{"/", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{
|
||||
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "/world",
|
||||
"/hello": "/world",
|
||||
}},
|
||||
@ -136,37 +137,37 @@ var (
|
||||
"/hello": "/",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Redirect{Path: "world"}, mss{
|
||||
{"/hello", target.Redirect{Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Redirect{Abs: true}, mss{
|
||||
{"/hello", target.Redirect{Flags: target.FlagAbs}, mss{
|
||||
"/": "",
|
||||
"/hello": "/",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Redirect{Abs: true, Path: "world"}, mss{
|
||||
{"/hello", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "",
|
||||
}},
|
||||
{"/hello", target.Redirect{Pre: true}, mss{
|
||||
{"/hello", target.Redirect{Flags: target.FlagPre}, mss{
|
||||
"/": "",
|
||||
"/hello": "/",
|
||||
"/hello/hi": "/hi",
|
||||
}},
|
||||
{"/hello", target.Redirect{Pre: true, Path: "world"}, mss{
|
||||
{"/hello", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "/world/hi",
|
||||
}},
|
||||
{"/hello", target.Redirect{Pre: true, Abs: true}, mss{
|
||||
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
|
||||
"/": "",
|
||||
"/hello": "/",
|
||||
"/hello/hi": "/",
|
||||
}},
|
||||
{"/hello", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{
|
||||
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
|
||||
"/": "",
|
||||
"/hello": "/world",
|
||||
"/hello/hi": "/world",
|
||||
@ -181,10 +182,10 @@ func TestRouter_AddRoute(t *testing.T) {
|
||||
for _, i := range routeTests {
|
||||
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
|
||||
dst := i.dst
|
||||
dst.Host = "127.0.0.1"
|
||||
dst.Port = 8080
|
||||
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
|
||||
dst.Src = path.Join("example.com", i.path)
|
||||
t.Logf("Running tests for %#v\n", dst)
|
||||
r.AddRoute("example.com", i.path, dst)
|
||||
r.AddRoute(dst)
|
||||
for k, v := range i.tests {
|
||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
|
||||
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
|
||||
@ -217,10 +218,11 @@ func TestRouter_AddRedirect(t *testing.T) {
|
||||
for _, i := range redirectTests {
|
||||
r := New(nil)
|
||||
dst := i.dst
|
||||
dst.Host = "example.com"
|
||||
dst.Dst = path.Join("example.com", dst.Dst)
|
||||
dst.Code = http.StatusFound
|
||||
dst.Src = path.Join("www.example.com", i.path)
|
||||
t.Logf("Running tests for %#v\n", dst)
|
||||
r.AddRedirect("www.example.com", i.path, dst)
|
||||
r.AddRedirect(dst)
|
||||
for k, v := range i.tests {
|
||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
|
||||
if v == "" {
|
||||
@ -266,10 +268,10 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
|
||||
for _, i := range routeTests {
|
||||
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
|
||||
dst := i.dst
|
||||
dst.Host = "127.0.0.1"
|
||||
dst.Port = 8080
|
||||
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
|
||||
dst.Src = path.Join("example.com", i.path)
|
||||
t.Logf("Running tests for %#v\n", dst)
|
||||
r.AddRoute("example.com", i.path, dst)
|
||||
r.AddRoute(dst)
|
||||
for k, v := range i.tests {
|
||||
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
|
||||
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
|
||||
|
115
servers/api.go
115
servers/api.go
@ -1,115 +0,0 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"github.com/MrMelon54/mjwt"
|
||||
"github.com/MrMelon54/mjwt/auth"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewApiServer creates and runs a http server containing all the API
|
||||
// endpoints for the software
|
||||
//
|
||||
// `/compile` - reloads all domains, routes and redirects
|
||||
func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server {
|
||||
r := httprouter.New()
|
||||
|
||||
// Endpoint for compile action
|
||||
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||
if !hasPerms(conf.Signer, req, "violet:compile") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger the compile action
|
||||
compileTarget.Compile()
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
// Endpoint for domains
|
||||
r.PUT("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if !hasPerms(conf.Signer, req, "violet:domains") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// add domain with active state
|
||||
q := req.URL.Query()
|
||||
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
|
||||
conf.Domains.Compile()
|
||||
})
|
||||
r.DELETE("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if !hasPerms(conf.Signer, req, "violet:domains") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// add domain with active state
|
||||
q := req.URL.Query()
|
||||
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
|
||||
conf.Domains.Compile()
|
||||
})
|
||||
|
||||
// Endpoint for routes
|
||||
r.POST("/route", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||
|
||||
})
|
||||
|
||||
// Endpoint for acme-challenge
|
||||
r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if !hasPerms(conf.Signer, req, "violet:acme-challenge") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
domain := params.ByName("domain")
|
||||
if !conf.Domains.IsValid(domain) {
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
||||
return
|
||||
}
|
||||
conf.Acme.Put(domain, params.ByName("key"), params.ByName("value"))
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if !hasPerms(conf.Signer, req, "violet:acme-challenge") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
domain := params.ByName("domain")
|
||||
if !conf.Domains.IsValid(domain) {
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
||||
return
|
||||
}
|
||||
conf.Acme.Delete(domain, params.ByName("key"))
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
// Create and run http server
|
||||
return &http.Server{
|
||||
Addr: conf.ApiListen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
WriteTimeout: time.Minute,
|
||||
IdleTimeout: time.Minute,
|
||||
MaxHeaderBytes: 2500,
|
||||
}
|
||||
}
|
||||
|
||||
func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool {
|
||||
// Get bearer token
|
||||
bearer := utils.GetBearer(req)
|
||||
if bearer == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Read claims from mjwt
|
||||
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Token must have perm
|
||||
return b.Claims.Perms.Has(perm)
|
||||
}
|
102
servers/api/api.go
Normal file
102
servers/api/api.go
Normal file
@ -0,0 +1,102 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/MrMelon54/mjwt"
|
||||
"github.com/MrMelon54/mjwt/claims"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewApiServer creates and runs a http server containing all the API
|
||||
// endpoints for the software
|
||||
//
|
||||
// `/compile` - reloads all domains, routes and redirects
|
||||
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Server {
|
||||
r := httprouter.New()
|
||||
|
||||
// Endpoint for compile action
|
||||
r.POST("/compile", checkAuthWithPerm(conf.Signer, "violet:compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, b AuthClaims) {
|
||||
// Trigger the compile action
|
||||
compileTarget.Compile()
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
}))
|
||||
|
||||
// Endpoint for domains
|
||||
domainFunc := domainManage(conf.Signer, conf.Domains)
|
||||
r.PUT("/domain/:domain", domainFunc)
|
||||
r.DELETE("/domain/:domain", domainFunc)
|
||||
|
||||
// Endpoint code for target routes/redirects
|
||||
targetApis := SetupTargetApis(conf.Signer, conf.Router)
|
||||
|
||||
// Endpoint for routes
|
||||
r.POST("/route", targetApis.CreateRoute)
|
||||
r.DELETE("/route", targetApis.DeleteRoute)
|
||||
|
||||
// Endpoint for redirects
|
||||
r.POST("/redirect", targetApis.CreateRedirect)
|
||||
r.DELETE("/redirect", targetApis.DeleteRedirect)
|
||||
|
||||
// Endpoint for acme-challenge
|
||||
acmeChallengeFunc := acmeChallengeManage(conf.Signer, conf.Domains, conf.Acme)
|
||||
r.PUT("/acme-challenge/:domain/:key/:value", acmeChallengeFunc)
|
||||
r.DELETE("/acme-challenge/:domain/:key", acmeChallengeFunc)
|
||||
|
||||
// Create and run http server
|
||||
return &http.Server{
|
||||
Addr: conf.ApiListen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
WriteTimeout: time.Minute,
|
||||
IdleTimeout: time.Minute,
|
||||
MaxHeaderBytes: 2500,
|
||||
}
|
||||
}
|
||||
|
||||
// apiError outputs a generic JSON error message
|
||||
func apiError(rw http.ResponseWriter, code int, m string) {
|
||||
rw.WriteHeader(code)
|
||||
_ = json.NewEncoder(rw).Encode(map[string]string{
|
||||
"error": m,
|
||||
})
|
||||
}
|
||||
|
||||
func domainManage(verify mjwt.Verifier, domains utils.DomainProvider) httprouter.Handle {
|
||||
return checkAuthWithPerm(verify, "violet:domains", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||
// add domain with active state
|
||||
domains.Put(params.ByName("domain"), req.Method == http.MethodPut)
|
||||
domains.Compile()
|
||||
})
|
||||
}
|
||||
|
||||
func acmeChallengeManage(verify mjwt.Verifier, domains utils.DomainProvider, acme utils.AcmeChallengeProvider) httprouter.Handle {
|
||||
return checkAuthWithPerm(verify, "violet:acme-challenge", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||
domain := params.ByName("domain")
|
||||
if !domains.IsValid(domain) {
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
||||
return
|
||||
}
|
||||
if req.Method == http.MethodPut {
|
||||
acme.Put(domain, params.ByName("key"), params.ByName("value"))
|
||||
} else {
|
||||
acme.Delete(domain, params.ByName("key"))
|
||||
}
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
}
|
||||
|
||||
// validateDomainOwnershipClaims validates if the claims contain the
|
||||
// `owns=<fqdn>` field with the matching top level domain
|
||||
func validateDomainOwnershipClaims(a string, perms *claims.PermStorage) bool {
|
||||
if fqdn, ok := utils.GetTopFqdn(a); ok {
|
||||
if perms.Has("owns=" + fqdn) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
@ -1,59 +1,22 @@
|
||||
package servers
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/MrMelon54/mjwt"
|
||||
"github.com/MrMelon54/mjwt/auth"
|
||||
"github.com/MrMelon54/mjwt/claims"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/MrMelon54/violet/utils/fake"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var snakeOilProv = genSnakeOilProv()
|
||||
|
||||
type fakeDomains struct{}
|
||||
|
||||
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
|
||||
func (f *fakeDomains) Put(string, bool) {}
|
||||
func (f *fakeDomains) Delete(string) {}
|
||||
func (f *fakeDomains) Compile() {}
|
||||
|
||||
func genSnakeOilProv() mjwt.Signer {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return mjwt.NewMJwtSigner("violet.test", key)
|
||||
}
|
||||
|
||||
func genSnakeOilKey(perm string) string {
|
||||
p := claims.NewPermStorage()
|
||||
p.Set(perm)
|
||||
val, err := snakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
type fakeCompilable struct{ done bool }
|
||||
|
||||
func (f *fakeCompilable) Compile() { f.done = true }
|
||||
|
||||
var _ utils.Compilable = &fakeCompilable{}
|
||||
|
||||
func TestNewApiServer_Compile(t *testing.T) {
|
||||
apiConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
apiConf := &conf.Conf{
|
||||
Domains: &fake.Domains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Signer: snakeOilProv,
|
||||
Signer: fake.SnakeOilProv,
|
||||
}
|
||||
f := &fakeCompilable{}
|
||||
f := &fake.Compilable{}
|
||||
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
|
||||
@ -63,25 +26,25 @@ func TestNewApiServer_Compile(t *testing.T) {
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
assert.False(t, f.done)
|
||||
assert.False(t, f.Done)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile"))
|
||||
req.Header.Set("Authorization", "Bearer "+fake.GenSnakeOilKey("violet:compile"))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.True(t, f.done)
|
||||
assert.True(t, f.Done)
|
||||
}
|
||||
|
||||
func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
|
||||
apiConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
apiConf := &conf.Conf{
|
||||
Domains: &fake.Domains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Signer: snakeOilProv,
|
||||
Signer: fake.SnakeOilProv,
|
||||
}
|
||||
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
||||
acmeKey := genSnakeOilKey("violet:acme-challenge")
|
||||
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
|
||||
|
||||
// Valid domain
|
||||
req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil)
|
||||
@ -119,13 +82,13 @@ func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
|
||||
apiConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
apiConf := &conf.Conf{
|
||||
Domains: &fake.Domains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Signer: snakeOilProv,
|
||||
Signer: fake.SnakeOilProv,
|
||||
}
|
||||
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
||||
acmeKey := genSnakeOilKey("violet:acme-challenge")
|
||||
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
|
||||
|
||||
// Valid domain
|
||||
req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil)
|
49
servers/api/auth.go
Normal file
49
servers/api/auth.go
Normal file
@ -0,0 +1,49 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/MrMelon54/mjwt"
|
||||
"github.com/MrMelon54/mjwt/auth"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims]
|
||||
|
||||
type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims)
|
||||
|
||||
// checkAuth validates the bearer token against a mjwt.Verifier and returns an
|
||||
// error message or continues to the next handler
|
||||
func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle {
|
||||
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
// Get bearer token
|
||||
bearer := utils.GetBearer(req)
|
||||
if bearer == "" {
|
||||
apiError(rw, http.StatusForbidden, "Missing bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
// Read claims from mjwt
|
||||
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
||||
if err != nil {
|
||||
apiError(rw, http.StatusForbidden, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
cb(rw, req, params, AuthClaims(b))
|
||||
}
|
||||
}
|
||||
|
||||
// checkAuthWithPerm validates the bearer token and checks if it contains a
|
||||
// required permission and returns an error message or continues to the next
|
||||
// handler
|
||||
func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle {
|
||||
return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||
// check perms
|
||||
if !b.Claims.Perms.Has(perm) {
|
||||
apiError(rw, http.StatusForbidden, "No permission")
|
||||
return
|
||||
}
|
||||
cb(rw, req, params, b)
|
||||
})
|
||||
}
|
27
servers/api/target-types.go
Normal file
27
servers/api/target-types.go
Normal file
@ -0,0 +1,27 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/MrMelon54/violet/target"
|
||||
)
|
||||
|
||||
type sourceJson struct {
|
||||
Src string `json:"src"`
|
||||
}
|
||||
|
||||
func (s sourceJson) GetSource() string { return s.Src }
|
||||
|
||||
type routeSource target.Route
|
||||
|
||||
func (r routeSource) GetSource() string { return r.Src }
|
||||
|
||||
type redirectSource target.Redirect
|
||||
|
||||
func (r redirectSource) GetSource() string { return r.Src }
|
||||
|
||||
var (
|
||||
_ sourceGetter = sourceJson{}
|
||||
_ sourceGetter = routeSource{}
|
||||
_ sourceGetter = redirectSource{}
|
||||
)
|
||||
|
||||
type sourceGetter interface{ GetSource() string }
|
88
servers/api/target.go
Normal file
88
servers/api/target.go
Normal file
@ -0,0 +1,88 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/MrMelon54/mjwt"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/target"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TargetApis struct {
|
||||
CreateRoute httprouter.Handle
|
||||
DeleteRoute httprouter.Handle
|
||||
CreateRedirect httprouter.Handle
|
||||
DeleteRedirect httprouter.Handle
|
||||
}
|
||||
|
||||
func SetupTargetApis(verify mjwt.Verifier, manager *router.Manager) *TargetApis {
|
||||
r := &TargetApis{
|
||||
CreateRoute: parseJsonAndCheckOwnership[routeSource](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t routeSource) {
|
||||
err := manager.InsertRoute(target.Route(t))
|
||||
if err != nil {
|
||||
log.Printf("[Violet] Failed to insert route into database: %s\n", err)
|
||||
apiError(rw, http.StatusInternalServerError, "Failed to insert route into database")
|
||||
return
|
||||
}
|
||||
manager.Compile()
|
||||
}),
|
||||
DeleteRoute: parseJsonAndCheckOwnership[sourceJson](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
|
||||
err := manager.DeleteRoute(t.Src)
|
||||
if err != nil {
|
||||
log.Printf("[Violet] Failed to delete route from database: %s\n", err)
|
||||
apiError(rw, http.StatusInternalServerError, "Failed to delete route from database")
|
||||
return
|
||||
}
|
||||
manager.Compile()
|
||||
}),
|
||||
CreateRedirect: parseJsonAndCheckOwnership[redirectSource](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t redirectSource) {
|
||||
err := manager.InsertRedirect(target.Redirect(t))
|
||||
if err != nil {
|
||||
log.Printf("[Violet] Failed to insert redirect into database: %s\n", err)
|
||||
apiError(rw, http.StatusInternalServerError, "Failed to insert redirect into database")
|
||||
return
|
||||
}
|
||||
manager.Compile()
|
||||
}),
|
||||
DeleteRedirect: parseJsonAndCheckOwnership[sourceJson](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
|
||||
err := manager.DeleteRedirect(t.Src)
|
||||
if err != nil {
|
||||
log.Printf("[Violet] Failed to delete redirect from database: %s\n", err)
|
||||
apiError(rw, http.StatusInternalServerError, "Failed to delete redirect from database")
|
||||
return
|
||||
}
|
||||
manager.Compile()
|
||||
}),
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
type AuthWithJsonCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t T)
|
||||
|
||||
func parseJsonAndCheckOwnership[T sourceGetter](verify mjwt.Verifier, t string, cb AuthWithJsonCallback[T]) httprouter.Handle {
|
||||
return checkAuthWithPerm(verify, "violet:"+t, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||
var j T
|
||||
if json.NewDecoder(req.Body).Decode(&j) != nil {
|
||||
apiError(rw, http.StatusBadRequest, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// check token owns this domain
|
||||
host, _ := utils.SplitHostPath(j.GetSource())
|
||||
if strings.IndexByte(host, ':') != -1 {
|
||||
apiError(rw, http.StatusBadRequest, "Invalid route source")
|
||||
return
|
||||
}
|
||||
|
||||
if !validateDomainOwnershipClaims(host, b.Claims.Perms) {
|
||||
apiError(rw, http.StatusBadRequest, "Token cannot modify the specified domain")
|
||||
return
|
||||
}
|
||||
|
||||
cb(rw, req, params, b, j)
|
||||
})
|
||||
}
|
@ -1,12 +1,12 @@
|
||||
package servers
|
||||
package conf
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"github.com/MrMelon54/mjwt"
|
||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
)
|
||||
|
||||
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
|
||||
@ -16,29 +16,11 @@ type Conf struct {
|
||||
HttpsListen string // https server listen address
|
||||
RateLimit uint64 // rate limit per minute
|
||||
DB *sql.DB
|
||||
Domains DomainProvider
|
||||
Acme AcmeChallengeProvider
|
||||
Certs CertProvider
|
||||
Domains utils.DomainProvider
|
||||
Acme utils.AcmeChallengeProvider
|
||||
Certs utils.CertProvider
|
||||
Favicons *favicons.Favicons
|
||||
Signer mjwt.Verifier
|
||||
ErrorPages *errorPages.ErrorPages
|
||||
Router *router.Manager
|
||||
}
|
||||
|
||||
type DomainProvider interface {
|
||||
IsValid(host string) bool
|
||||
Put(domain string, active bool)
|
||||
Delete(domain string)
|
||||
Compile()
|
||||
}
|
||||
|
||||
type AcmeChallengeProvider interface {
|
||||
Get(domain, key string) string
|
||||
Put(domain, key, value string)
|
||||
Delete(domain, key string)
|
||||
}
|
||||
|
||||
type CertProvider interface {
|
||||
GetCertForDomain(domain string) *tls.Certificate
|
||||
Compile()
|
||||
}
|
@ -2,6 +2,7 @@ package servers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
@ -14,7 +15,7 @@ import (
|
||||
//
|
||||
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||
// acme challenges, this is used for Let's Encrypt HTTP verification.
|
||||
func NewHttpServer(conf *Conf) *http.Server {
|
||||
func NewHttpServer(conf *conf.Conf) *http.Server {
|
||||
r := httprouter.New()
|
||||
var secureExtend string
|
||||
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
|
||||
|
@ -2,7 +2,9 @@ package servers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/MrMelon54/violet/utils/fake"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -11,10 +13,10 @@ import (
|
||||
)
|
||||
|
||||
func TestNewHttpServer_AcmeChallenge(t *testing.T) {
|
||||
httpConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
httpConf := &conf.Conf{
|
||||
Domains: &fake.Domains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Signer: snakeOilProv,
|
||||
Signer: fake.SnakeOilProv,
|
||||
}
|
||||
srv := NewHttpServer(httpConf)
|
||||
httpConf.Acme.Put("example.com", "456", "456def")
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/sethvargo/go-limiter/httplimit"
|
||||
"github.com/sethvargo/go-limiter/memorystore"
|
||||
@ -16,7 +17,7 @@ import (
|
||||
|
||||
// NewHttpsServer creates and runs a http server containing the public https
|
||||
// endpoints for the reverse proxy.
|
||||
func NewHttpsServer(conf *Conf) *http.Server {
|
||||
func NewHttpsServer(conf *conf.Conf) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: conf.HttpsListen,
|
||||
Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"github.com/MrMelon54/violet/certs"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/servers/conf"
|
||||
"github.com/MrMelon54/violet/utils/fake"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
@ -26,11 +28,11 @@ func TestNewHttpsServer_RateLimit(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
ft := &fakeTransport{}
|
||||
httpsConf := &Conf{
|
||||
httpsConf := &conf.Conf{
|
||||
RateLimit: 5,
|
||||
Domains: &fakeDomains{},
|
||||
Domains: &fake.Domains{},
|
||||
Certs: certs.New(nil, nil, true),
|
||||
Signer: snakeOilProv,
|
||||
Signer: fake.SnakeOilProv,
|
||||
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
|
||||
}
|
||||
srv := NewHttpsServer(httpsConf)
|
||||
|
41
target/flags.go
Normal file
41
target/flags.go
Normal file
@ -0,0 +1,41 @@
|
||||
package target
|
||||
|
||||
type Flags uint64
|
||||
|
||||
const (
|
||||
FlagPre Flags = 1 << iota
|
||||
FlagAbs
|
||||
FlagCors
|
||||
FlagSecureMode
|
||||
FlagForwardHost
|
||||
FlagForwardAddr
|
||||
FlagIgnoreCert
|
||||
)
|
||||
|
||||
var (
|
||||
routeFlagMask = FlagPre | FlagAbs | FlagCors | FlagSecureMode | FlagForwardHost | FlagForwardAddr | FlagIgnoreCert
|
||||
redirectFlagMask = FlagPre | FlagAbs
|
||||
)
|
||||
|
||||
// HasFlag returns true if the bits contain the requested flag
|
||||
func (f Flags) HasFlag(flag Flags) bool {
|
||||
// 0110 & 0100 == 0100 (value != 0 thus true)
|
||||
// 0011 & 0100 == 0000 (value == 0 thus false)
|
||||
return f&flag != 0
|
||||
}
|
||||
|
||||
// NormaliseRouteFlags returns only the bits used for routes
|
||||
func (f Flags) NormaliseRouteFlags() Flags {
|
||||
// removes bits outside the mask
|
||||
// 0110 & 0111 == 0110
|
||||
// 1010 & 0111 == 0010 (values are different)
|
||||
return f & routeFlagMask
|
||||
}
|
||||
|
||||
// NormaliseRedirectFlags returns only the bits used for redirects
|
||||
func (f Flags) NormaliseRedirectFlags() Flags {
|
||||
// removes bits outside the mask
|
||||
// 0110 & 0111 == 0110
|
||||
// 1010 & 0111 == 0010 (values are different)
|
||||
return f & redirectFlagMask
|
||||
}
|
@ -12,20 +12,14 @@ import (
|
||||
// Redirect is a target used by the router to manage redirecting the request
|
||||
// using the specified configuration.
|
||||
type Redirect struct {
|
||||
Pre bool // if the path has had a prefix removed
|
||||
Host string // target host
|
||||
Port int // target port
|
||||
Path string // target path (possibly a prefix or absolute)
|
||||
Abs bool // if the path is a prefix or absolute
|
||||
Code int // status code used to redirect
|
||||
Src string `json:"src"` // request source
|
||||
Dst string `json:"dst"` // redirect destination
|
||||
Flags Flags `json:"flags"` // extra flags
|
||||
Code int `json:"code"` // status code used to redirect
|
||||
}
|
||||
|
||||
// FullHost outputs a host:port combo or just the host if the port is 0.
|
||||
func (r Redirect) FullHost() string {
|
||||
if r.Port == 0 {
|
||||
return r.Host
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
func (r Route) HasFlag(flag Flags) bool {
|
||||
return r.Flags&flag != 0
|
||||
}
|
||||
|
||||
// ServeHTTP responds with the redirect to the response writer provided.
|
||||
@ -36,10 +30,12 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
code = http.StatusFound
|
||||
}
|
||||
|
||||
// split the host and path
|
||||
host, p := utils.SplitHostPath(r.Dst)
|
||||
|
||||
// if not Abs then join with the ending of the current path
|
||||
p := r.Path
|
||||
if !r.Abs {
|
||||
p = path.Join(r.Path, req.URL.Path)
|
||||
if !r.Flags.HasFlag(FlagAbs) {
|
||||
p = path.Join(p, req.URL.Path)
|
||||
|
||||
// replace the trailing slash that path.Join() strips off
|
||||
if strings.HasSuffix(req.URL.Path, "/") {
|
||||
@ -55,7 +51,7 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// create a new URL
|
||||
u := &url.URL{
|
||||
Scheme: req.URL.Scheme,
|
||||
Host: r.FullHost(),
|
||||
Host: host,
|
||||
Path: p,
|
||||
}
|
||||
|
||||
|
@ -7,18 +7,13 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedirect_FullHost(t *testing.T) {
|
||||
assert.Equal(t, "localhost", Redirect{Host: "localhost"}.FullHost())
|
||||
assert.Equal(t, "localhost:22", Redirect{Host: "localhost", Port: 22}.FullHost())
|
||||
}
|
||||
|
||||
func TestRedirect_ServeHTTP(t *testing.T) {
|
||||
a := []struct {
|
||||
Redirect
|
||||
target string
|
||||
}{
|
||||
{Redirect{Host: "example.com", Path: "/bye", Abs: true, Code: http.StatusFound}, "https://example.com/bye"},
|
||||
{Redirect{Host: "example.com", Path: "/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"},
|
||||
{Redirect{Dst: "example.com/bye", Flags: FlagAbs, Code: http.StatusFound}, "https://example.com/bye"},
|
||||
{Redirect{Dst: "example.com/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"},
|
||||
}
|
||||
for _, i := range a {
|
||||
res := httptest.NewRecorder()
|
||||
|
@ -36,18 +36,11 @@ var serveApiCors = cors.New(cors.Options{
|
||||
// Route is a target used by the router to manage forwarding traffic to an
|
||||
// internal server using the specified configuration.
|
||||
type Route struct {
|
||||
Pre bool // if the path has had a prefix removed
|
||||
Host string // target host
|
||||
Port int // target port
|
||||
Path string // target path (possibly a prefix or absolute)
|
||||
Abs bool // if the path is a prefix or absolute
|
||||
Cors bool // add CORS headers
|
||||
SecureMode bool // use HTTPS internally
|
||||
ForwardHost bool // forward host header internally
|
||||
ForwardAddr bool // forward remote address
|
||||
IgnoreCert bool // ignore self-cert
|
||||
Headers http.Header // extra headers
|
||||
Proxy *proxy.HybridTransport // reverse proxy handler
|
||||
Src string `json:"src"` // request source
|
||||
Dst string `json:"dst"` // proxy destination
|
||||
Flags Flags `json:"flags"` // extra flags
|
||||
Headers http.Header `json:"-"` // extra headers
|
||||
Proxy *proxy.HybridTransport `json:"-"` // reverse proxy handler
|
||||
}
|
||||
|
||||
// UpdateHeaders takes an existing set of headers and overwrites them with the
|
||||
@ -58,18 +51,10 @@ func (r Route) UpdateHeaders(header http.Header) {
|
||||
}
|
||||
}
|
||||
|
||||
// FullHost outputs a host:port combo or just the host if the port is 0.
|
||||
func (r Route) FullHost() string {
|
||||
if r.Port == 0 {
|
||||
return r.Host
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
// ServeHTTP responds with the data proxied from the internal server to the
|
||||
// response writer provided.
|
||||
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if r.Cors {
|
||||
if r.HasFlag(FlagCors) {
|
||||
// wraps with CORS handler
|
||||
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
|
||||
} else {
|
||||
@ -82,21 +67,16 @@ func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// set the scheme and port using defaults if the port is 0
|
||||
scheme := "http"
|
||||
if r.SecureMode {
|
||||
if r.HasFlag(FlagSecureMode) {
|
||||
scheme = "https"
|
||||
if r.Port == 0 {
|
||||
r.Port = 443
|
||||
}
|
||||
} else {
|
||||
if r.Port == 0 {
|
||||
r.Port = 80
|
||||
}
|
||||
}
|
||||
|
||||
// split the host and path
|
||||
host, p := utils.SplitHostPath(r.Dst)
|
||||
|
||||
// if not Abs then join with the ending of the current path
|
||||
p := r.Path
|
||||
if !r.Abs {
|
||||
p = path.Join(r.Path, req.URL.Path)
|
||||
if !r.HasFlag(FlagAbs) {
|
||||
p = path.Join(p, req.URL.Path)
|
||||
|
||||
// replace the trailing slash that path.Join() strips off
|
||||
if strings.HasSuffix(req.URL.Path, "/") {
|
||||
@ -112,7 +92,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// create a new URL
|
||||
u := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: r.FullHost(),
|
||||
Host: host,
|
||||
Path: p,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
}
|
||||
@ -150,10 +130,10 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
// if forward host is enabled then send the host
|
||||
if r.ForwardHost {
|
||||
if r.HasFlag(FlagForwardHost) {
|
||||
req2.Host = req.Host
|
||||
}
|
||||
if r.ForwardAddr {
|
||||
if r.HasFlag(FlagForwardAddr) {
|
||||
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||
}
|
||||
|
||||
@ -162,7 +142,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// serve request with reverse proxy
|
||||
var resp *http.Response
|
||||
if r.IgnoreCert {
|
||||
if r.HasFlag(FlagIgnoreCert) {
|
||||
resp, err = r.Proxy.InsecureRoundTrip(req2)
|
||||
} else {
|
||||
resp, err = r.Proxy.SecureRoundTrip(req2)
|
||||
|
@ -25,9 +25,9 @@ func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: http.StatusOK}, nil
|
||||
}
|
||||
|
||||
func TestRoute_FullHost(t *testing.T) {
|
||||
assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost())
|
||||
assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost())
|
||||
func TestRoute_HasFlag(t *testing.T) {
|
||||
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
|
||||
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
|
||||
}
|
||||
|
||||
func TestRoute_ServeHTTP(t *testing.T) {
|
||||
@ -35,12 +35,12 @@ func TestRoute_ServeHTTP(t *testing.T) {
|
||||
Route
|
||||
target string
|
||||
}{
|
||||
{Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"},
|
||||
{Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"},
|
||||
{Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"},
|
||||
{Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"},
|
||||
{Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"},
|
||||
{Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"},
|
||||
{Route{Dst: "localhost:1234/bye", Flags: FlagAbs}, "http://localhost:1234/bye"},
|
||||
{Route{Dst: "1.2.3.4/bye"}, "http://1.2.3.4/bye/hello/world"},
|
||||
{Route{Dst: "2.2.2.2/world", Flags: FlagAbs | FlagSecureMode}, "https://2.2.2.2/world"},
|
||||
{Route{Dst: "api.example.com/world", Flags: FlagAbs | FlagSecureMode | FlagForwardHost}, "https://api.example.com/world"},
|
||||
{Route{Dst: "api.example.org/world", Flags: FlagAbs | FlagSecureMode | FlagForwardAddr}, "https://api.example.org/world"},
|
||||
{Route{Dst: "3.3.3.3/headers", Flags: FlagAbs, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3/headers"},
|
||||
}
|
||||
for _, i := range a {
|
||||
pt := &proxyTester{}
|
||||
@ -51,10 +51,10 @@ func TestRoute_ServeHTTP(t *testing.T) {
|
||||
|
||||
assert.True(t, pt.got)
|
||||
assert.Equal(t, i.target, pt.req.URL.String())
|
||||
if i.ForwardAddr {
|
||||
if i.HasFlag(FlagForwardAddr) {
|
||||
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
|
||||
}
|
||||
if i.ForwardHost {
|
||||
if i.HasFlag(FlagForwardHost) {
|
||||
assert.Equal(t, req.Host, pt.req.Host)
|
||||
}
|
||||
if i.Headers != nil {
|
||||
@ -68,7 +68,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
|
||||
req.Header.Set("Origin", "https://test.example.com")
|
||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
|
||||
i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()}
|
||||
i.ServeHTTP(res, req)
|
||||
|
||||
assert.True(t, pt.got)
|
||||
@ -86,7 +86,7 @@ func TestRoute_ServeHTTP_Body(t *testing.T) {
|
||||
buf := bytes.NewBuffer([]byte{0x54})
|
||||
req := httptest.NewRequest(http.MethodPost, "https://www.example.com/test", buf)
|
||||
req.Header.Set("Origin", "https://test.example.com")
|
||||
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
|
||||
i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()}
|
||||
i.ServeHTTP(res, req)
|
||||
|
||||
assert.True(t, pt.got)
|
||||
|
@ -83,3 +83,38 @@ func GetTopFqdn(domain string) (string, bool) {
|
||||
}
|
||||
return domain[n+1:], true
|
||||
}
|
||||
|
||||
// SplitHostPath extracts the host/path from the input
|
||||
func SplitHostPath(a string) (host, path string) {
|
||||
// check if source has path
|
||||
n := strings.IndexByte(a, '/')
|
||||
if n == -1 {
|
||||
// set host then path to /
|
||||
host = a
|
||||
path = "/"
|
||||
} else {
|
||||
// set host then custom path
|
||||
host = a[:n]
|
||||
path = a[n:] // this required to keep / at the start of the path
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SplitHostPathQuery extracts the host/path?query from the input
|
||||
func SplitHostPathQuery(a string) (host, path, query string) {
|
||||
host, path = SplitHostPath(a)
|
||||
if path == "/" {
|
||||
n := strings.IndexByte(host, '?')
|
||||
if n != -1 {
|
||||
query = host[n+1:]
|
||||
host = host[:n]
|
||||
}
|
||||
return
|
||||
}
|
||||
n := strings.IndexByte(path, '?')
|
||||
if n != -1 {
|
||||
query = path[n+1:]
|
||||
path = path[:n] // reassign happens after
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -60,3 +60,40 @@ func TestGetTopFqdn(t *testing.T) {
|
||||
assert.True(t, ok, "Output should be true")
|
||||
assert.Equal(t, "example.com", domain)
|
||||
}
|
||||
|
||||
func TestSplitHostPath(t *testing.T) {
|
||||
h, p := SplitHostPath("example.com/hello/world")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/hello/world", p)
|
||||
|
||||
h, p = SplitHostPath("example.com")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/", p)
|
||||
}
|
||||
|
||||
func TestSplitHostPathQuery(t *testing.T) {
|
||||
h, p, q := SplitHostPathQuery("example.com/hello/world")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/hello/world", p)
|
||||
assert.Equal(t, "", q)
|
||||
|
||||
h, p, q = SplitHostPathQuery("example.com")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/", p)
|
||||
assert.Equal(t, "", q)
|
||||
|
||||
h, p, q = SplitHostPathQuery("example.com/hello/world?a=b")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/hello/world", p)
|
||||
assert.Equal(t, "a=b", q)
|
||||
|
||||
h, p, q = SplitHostPathQuery("example.com?a=b")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/", p)
|
||||
assert.Equal(t, "a=b", q)
|
||||
|
||||
h, p, q = SplitHostPathQuery("example.com/?a=b")
|
||||
assert.Equal(t, "example.com", h)
|
||||
assert.Equal(t, "/", p)
|
||||
assert.Equal(t, "a=b", q)
|
||||
}
|
||||
|
11
utils/fake/fake-compilable.go
Normal file
11
utils/fake/fake-compilable.go
Normal file
@ -0,0 +1,11 @@
|
||||
package fake
|
||||
|
||||
import "github.com/MrMelon54/violet/utils"
|
||||
|
||||
// Compilable implements utils.Compilable and stores if the Compile function
|
||||
// is called.
|
||||
type Compilable struct{ Done bool }
|
||||
|
||||
func (f *Compilable) Compile() { f.Done = true }
|
||||
|
||||
var _ utils.Compilable = &Compilable{}
|
13
utils/fake/fake-domains.go
Normal file
13
utils/fake/fake-domains.go
Normal file
@ -0,0 +1,13 @@
|
||||
package fake
|
||||
|
||||
import "github.com/MrMelon54/violet/utils"
|
||||
|
||||
// Domains implements DomainProvider and makes sure `example.com` is valid
|
||||
type Domains struct{}
|
||||
|
||||
func (f *Domains) IsValid(host string) bool { return host == "example.com" }
|
||||
func (f *Domains) Put(string, bool) {}
|
||||
func (f *Domains) Delete(string) {}
|
||||
func (f *Domains) Compile() {}
|
||||
|
||||
var _ utils.DomainProvider = &Domains{}
|
2
utils/fake/fake.go
Normal file
2
utils/fake/fake.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package fake contains fake structs used during tests
|
||||
package fake
|
30
utils/fake/mjwt.go
Normal file
30
utils/fake/mjwt.go
Normal file
@ -0,0 +1,30 @@
|
||||
package fake
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/MrMelon54/mjwt"
|
||||
"github.com/MrMelon54/mjwt/auth"
|
||||
"github.com/MrMelon54/mjwt/claims"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SnakeOilProv = GenSnakeOilProv()
|
||||
|
||||
func GenSnakeOilProv() mjwt.Signer {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return mjwt.NewMJwtSigner("violet.test", key)
|
||||
}
|
||||
|
||||
func GenSnakeOilKey(perm string) string {
|
||||
p := claims.NewPermStorage()
|
||||
p.Set(perm)
|
||||
val, err := SnakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return val
|
||||
}
|
21
utils/interfaces.go
Normal file
21
utils/interfaces.go
Normal file
@ -0,0 +1,21 @@
|
||||
package utils
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
type DomainProvider interface {
|
||||
IsValid(host string) bool
|
||||
Put(domain string, active bool)
|
||||
Delete(domain string)
|
||||
Compile()
|
||||
}
|
||||
|
||||
type AcmeChallengeProvider interface {
|
||||
Get(domain, key string) string
|
||||
Put(domain, key, value string)
|
||||
Delete(domain, key string)
|
||||
}
|
||||
|
||||
type CertProvider interface {
|
||||
GetCertForDomain(domain string) *tls.Certificate
|
||||
Compile()
|
||||
}
|
Loading…
Reference in New Issue
Block a user