Write route/redirect APIs and rearrage some other code to make it possible

This commit is contained in:
Melon 2023-07-12 16:55:09 +01:00
parent c930ddff28
commit 949dcd298a
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
37 changed files with 683 additions and 500 deletions

4
.idea/dataSources.xml generated
View File

@ -1,11 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="__db.sqlite" uuid="5aeb4e88-8ec4-4227-a921-ba4eaed357bf">
<data-source source="LOCAL" name="identifier.sqlite" uuid="5b42d21a-92a8-43d0-8651-c1555b91060c">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:__db.sqlite</jdbc-url>
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>

View File

@ -14,6 +14,8 @@ import (
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers"
"github.com/MrMelon54/violet/servers/api"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/google/subcommands"
"io/fs"
@ -70,9 +72,9 @@ func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{
return subcommands.ExitSuccess
}
func normalLoad(conf startUpConfig, wd string) {
func normalLoad(startUp startUpConfig, wd string) {
// the cert and key paths are useless in self-signed mode
if !conf.SelfSigned {
if !startUp.SelfSigned {
// create path to cert dir
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
if err != nil {
@ -87,11 +89,11 @@ func normalLoad(conf startUpConfig, wd string) {
// errorPageDir stores an FS interface for accessing the error page directory
var errorPageDir fs.FS
if conf.ErrorPagePath != "" {
errorPageDir = os.DirFS(conf.ErrorPagePath)
err := os.MkdirAll(conf.ErrorPagePath, os.ModePerm)
if startUp.ErrorPagePath != "" {
errorPageDir = os.DirFS(startUp.ErrorPagePath)
err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create error page path '%s'", conf.ErrorPagePath)
log.Fatalf("[Violet] Failed to create error page path '%s'", startUp.ErrorPagePath)
}
}
@ -110,20 +112,20 @@ func normalLoad(conf startUpConfig, wd string) {
certDir := os.DirFS(filepath.Join(wd, "certs"))
keyDir := os.DirFS(filepath.Join(wd, "keys"))
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, conf.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
dynamicFavicons := favicons.New(db, conf.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
// struct containing config for the http servers
srvConf := &servers.Conf{
ApiListen: conf.Listen.Api,
HttpListen: conf.Listen.Http,
HttpsListen: conf.Listen.Https,
RateLimit: conf.RateLimit,
srvConf := &conf.Conf{
ApiListen: startUp.Listen.Api,
HttpListen: startUp.Listen.Http,
HttpsListen: startUp.Listen.Https,
RateLimit: startUp.RateLimit,
DB: db,
Domains: allowedDomains,
Acme: acmeChallenges,
@ -140,7 +142,7 @@ func normalLoad(conf startUpConfig, wd string) {
var srvApi, srvHttp, srvHttps *http.Server
if srvConf.ApiListen != "" {
srvApi = servers.NewApiServer(srvConf, allCompilables)
srvApi = api.NewApiServer(srvConf, allCompilables)
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
go utils.RunBackgroundHttp("API", srvApi)
}

View File

@ -181,13 +181,15 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
// add with the route manager, no need to compile as this will run when opened
// with the serve subcommand
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}))
routeManager.Add(path.Join(apiUrl.Host, apiUrl.Path), target.Route{
Pre: true,
Host: answers.ApiListen,
Cors: true,
ForwardHost: true,
ForwardAddr: true,
}, true)
err = routeManager.InsertRoute(target.Route{
Src: path.Join(apiUrl.Host, apiUrl.Path),
Dst: answers.ApiListen,
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
})
if err != nil {
fmt.Println("[Violet] Failed to insert api route into database: ", err)
return subcommands.ExitFailure
}
}
fmt.Println("[Violet] Setup complete")

View File

@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS favicons (
id INTEGER PRIMARY KEY AUTOINCREMENT,
host VARCHAR,
svg VARCHAR,
png VARCHAR,
ico VARCHAR
);

View File

@ -2,6 +2,7 @@ package favicons
import (
"database/sql"
_ "embed"
"errors"
"fmt"
"github.com/MrMelon54/rescheduler"
@ -12,6 +13,9 @@ import (
var ErrFaviconNotFound = errors.New("favicon not found")
//go:embed create-table-favicons.sql
var createTableFavicons string
// Favicons is a dynamic favicon generator which supports overwriting favicons
type Favicons struct {
db *sql.DB
@ -32,7 +36,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
f.r = rescheduler.NewRescheduler(f.threadCompile)
// init favicons table
_, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`)
_, err := f.db.Exec(createTableFavicons)
if err != nil {
log.Printf("[WARN] Failed to generate 'favicons' table\n")
return nil

View File

@ -1,10 +0,0 @@
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT,
pre INTEGER,
destination TEXT,
abs INTEGER,
code INTEGER,
active INTEGER DEFAULT 1
);

View File

@ -1,14 +0,0 @@
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT,
pre INTEGER,
destination TEXT,
abs INTEGER,
cors INTEGER,
secure_mode INTEGER,
forward_host INTEGER,
forward_addr INTEGER,
ignore_cert INTEGER,
active INTEGER DEFAULT 1
);

18
router/create-tables.sql Normal file
View File

@ -0,0 +1,18 @@
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE,
destination TEXT,
flags INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
);
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE,
destination TEXT,
flags INTEGER DEFAULT 0,
code INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
);

View File

@ -3,15 +3,11 @@ package router
import (
"database/sql"
_ "embed"
"fmt"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"log"
"net/http"
"path"
"strings"
"sync"
)
@ -26,14 +22,8 @@ type Manager struct {
}
var (
//go:embed create-table-routes.sql
createTableRoutes string
//go:embed create-table-redirects.sql
createTableRedirects string
//go:embed query-table-routes.sql
queryTableRoutes string
//go:embed query-table-redirects.sql
queryTableRedirects string
//go:embed create-tables.sql
createTables string
)
// NewManager create a new manager, initialises the routes and redirects tables
@ -48,16 +38,9 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
m.z = rescheduler.NewRescheduler(m.threadCompile)
// init routes table
_, err := m.db.Exec(createTableRoutes)
_, err := m.db.Exec(createTables)
if err != nil {
log.Printf("[WARN] Failed to generate 'routes' table\n")
return nil
}
// init redirects table
_, err = m.db.Exec(createTableRedirects)
if err != nil {
log.Printf("[WARN] Failed to generate 'redirects' table\n")
log.Printf("[WARN] Failed to generate tables\n")
return nil
}
return m
@ -96,7 +79,7 @@ func (m *Manager) internalCompile(router *Router) error {
log.Println("[Manager] Updating routes from database")
// sql or something?
rows, err := m.db.Query(queryTableRoutes)
rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`)
if err != nil {
return err
}
@ -105,26 +88,19 @@ func (m *Manager) internalCompile(router *Router) error {
// loop through rows and scan the options
for rows.Next() {
var (
pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool
src, dst string
src, dst string
flags target.Flags
)
err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert)
err := rows.Scan(&src, &dst, &flags)
if err != nil {
return err
}
err = addRoute(router, src, dst, target.Route{
Pre: pre,
Abs: abs,
Cors: cors,
SecureMode: secure_mode,
ForwardHost: forward_host,
ForwardAddr: forward_addr,
IgnoreCert: ignore_cert,
router.AddRoute(target.Route{
Src: src,
Dst: dst,
Flags: flags.NormaliseRouteFlags(),
})
if err != nil {
return err
}
}
// check for errors
@ -133,7 +109,7 @@ func (m *Manager) internalCompile(router *Router) error {
}
// sql or something?
rows, err = m.db.Query(queryTableRedirects)
rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`)
if err != nil {
return err
}
@ -142,99 +118,51 @@ func (m *Manager) internalCompile(router *Router) error {
// loop through rows and scan the options
for rows.Next() {
var (
pre, abs bool
code int
src, dst string
flags target.Flags
code int
)
err := rows.Scan(&src, &pre, &dst, &abs, &code)
err := rows.Scan(&src, &dst, &flags, &code)
if err != nil {
return err
}
err = addRedirect(router, src, dst, target.Redirect{
Pre: pre,
Abs: abs,
Code: code,
router.AddRedirect(target.Redirect{
Src: src,
Dst: dst,
Flags: flags.NormaliseRedirectFlags(),
Code: code,
})
if err != nil {
return err
}
}
// check for errors
return rows.Err()
}
func (m *Manager) Add(source string, route target.Route, active bool) {
func (m *Manager) InsertRoute(route target.Route) error {
m.s.Lock()
defer m.s.Unlock()
_, err := m.db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, source, route.Pre, path.Join(route.Host, route.Path), route.Abs, route.Cors, route.SecureMode, route.ForwardHost, route.ForwardAddr, route.IgnoreCert, active)
if err != nil {
log.Printf("[Violet] Database error: %s\n", err)
}
_, err := m.db.Exec(`INSERT INTO routes (source, destination, flags) VALUES (?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, active = 1`, route.Src, route.Dst, route.Flags)
return err
}
// addRoute is an alias to parse the src and dst then add the route
func addRoute(router *Router, src string, dst string, t target.Route) error {
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
if err != nil {
return err
}
// update target route values and add route
t.Host = dstHost
t.Port = dstPort
t.Path = dstPath
router.AddRoute(srcHost, srcPath, t)
return nil
func (m *Manager) DeleteRoute(source string) error {
m.s.Lock()
defer m.s.Unlock()
_, err := m.db.Exec(`UPDATE routes SET active = 0 WHERE source = ?`, source)
return err
}
// addRedirect is an alias to parse the src and dst then add the redirect
func addRedirect(router *Router, src string, dst string, t target.Redirect) error {
srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst)
if err != nil {
return err
}
t.Host = dstHost
t.Port = dstPort
t.Path = dstPath
router.AddRedirect(srcHost, srcPath, t)
return nil
func (m *Manager) InsertRedirect(redirect target.Redirect) error {
m.s.Lock()
defer m.s.Unlock()
_, err := m.db.Exec(`INSERT INTO redirects (source, destination, flags, code) VALUES (?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, code = excluded.code, active = 1`, redirect.Src, redirect.Dst, redirect.Flags, redirect.Code)
return err
}
// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values
func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) {
// check if source has path
var srcHost, srcPath string
nSrc := strings.IndexByte(src, '/')
if nSrc == -1 {
// set host then path to /
srcHost = src
srcPath = "/"
} else {
// set host then custom path
srcHost = src[:nSrc]
srcPath = src[nSrc:]
}
// check if destination has path
var dstPath string
nDst := strings.IndexByte(dst, '/')
if nDst == -1 {
// set path to /
dstPath = "/"
} else {
// set custom path then trim dst string to the host
dstPath = dst[nDst:]
dst = dst[:nDst]
}
// try to split the destination host into domain + port
dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0)
if !ok {
return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst)
}
return srcHost, srcPath, dstHost, dstPort, dstPath, nil
func (m *Manager) DeleteRedirect(source string) error {
m.s.Lock()
defer m.s.Unlock()
_, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source)
return err
}

View File

@ -3,6 +3,7 @@ package router
import (
"database/sql"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"net/http"
@ -37,7 +38,7 @@ func TestNewManager(t *testing.T) {
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Nil(t, ft.req)
_, err = db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?,?,?,?,?,?,?,?,?,?)`, "*.example.com", 0, "127.0.0.1:8080", 1, 0, 0, 1, 1, 0, 1)
_, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr)
assert.NoError(t, err)
assert.NoError(t, m.internalCompile(m.r))

View File

@ -1,7 +0,0 @@
select source,
pre,
destination,
abs,
code
from redirects
where active = true

View File

@ -1,11 +0,0 @@
select source,
pre,
destination,
abs,
cors,
secure_mode,
forward_host,
forward_addr,
ignore_cert
from routes
where active = true

View File

@ -46,16 +46,14 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] {
return h
}
func (r *Router) AddService(host string, t target.Route) {
r.AddRoute(host, "/", t)
}
func (r *Router) AddRoute(host string, path string, t target.Route) {
func (r *Router) AddRoute(t target.Route) {
t.Proxy = r.proxy
host, path := utils.SplitHostPath(t.Src)
r.hostRoute(host).PutString(path, t)
}
func (r *Router) AddRedirect(host, path string, t target.Redirect) {
func (r *Router) AddRedirect(t target.Redirect) {
host, path := utils.SplitHostPath(t.Src)
r.hostRedirect(host).PutString(path, t)
}
@ -95,7 +93,7 @@ func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host
if h != nil {
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path {
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
@ -110,7 +108,7 @@ func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, ho
if h != nil {
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path {
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"path"
"testing"
)
@ -29,31 +30,31 @@ var (
"/": "/",
"/hello": "",
}},
{"/", target.Route{Path: "/world"}, mss{
{"/", target.Route{Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Route{Abs: true}, mss{
{"/", target.Route{Flags: target.FlagAbs}, mss{
"/": "/",
"/hello": "",
}},
{"/", target.Route{Abs: true, Path: "world"}, mss{
{"/", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Route{Pre: true}, mss{
{"/", target.Route{Flags: target.FlagPre}, mss{
"/": "/",
"/hello": "/hello",
}},
{"/", target.Route{Pre: true, Path: "world"}, mss{
{"/", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world/hello",
}},
{"/", target.Route{Pre: true, Abs: true}, mss{
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "/",
"/hello": "/",
}},
{"/", target.Route{Pre: true, Abs: true, Path: "world"}, mss{
{"/", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world",
}},
@ -62,37 +63,37 @@ var (
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Route{Path: "world"}, mss{
{"/hello", target.Route{Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Route{Abs: true}, mss{
{"/hello", target.Route{Flags: target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Route{Abs: true, Path: "world"}, mss{
{"/hello", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Route{Pre: true}, mss{
{"/hello", target.Route{Flags: target.FlagPre}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/hi",
}},
{"/hello", target.Route{Pre: true, Path: "world"}, mss{
{"/hello", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world/hi",
}},
{"/hello", target.Route{Pre: true, Abs: true}, mss{
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/",
}},
{"/hello", target.Route{Pre: true, Abs: true, Path: "world"}, mss{
{"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world",
@ -103,31 +104,31 @@ var (
"/": "/",
"/hello": "",
}},
{"/", target.Redirect{Path: "world"}, mss{
{"/", target.Redirect{Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Redirect{Abs: true}, mss{
{"/", target.Redirect{Flags: target.FlagAbs}, mss{
"/": "/",
"/hello": "",
}},
{"/", target.Redirect{Abs: true, Path: "world"}, mss{
{"/", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "",
}},
{"/", target.Redirect{Pre: true}, mss{
{"/", target.Redirect{Flags: target.FlagPre}, mss{
"/": "/",
"/hello": "/hello",
}},
{"/", target.Redirect{Pre: true, Path: "world"}, mss{
{"/", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world/hello",
}},
{"/", target.Redirect{Pre: true, Abs: true}, mss{
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "/",
"/hello": "/",
}},
{"/", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{
{"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "/world",
"/hello": "/world",
}},
@ -136,37 +137,37 @@ var (
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Path: "world"}, mss{
{"/hello", target.Redirect{Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Abs: true}, mss{
{"/hello", target.Redirect{Flags: target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Abs: true, Path: "world"}, mss{
{"/hello", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "",
}},
{"/hello", target.Redirect{Pre: true}, mss{
{"/hello", target.Redirect{Flags: target.FlagPre}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/hi",
}},
{"/hello", target.Redirect{Pre: true, Path: "world"}, mss{
{"/hello", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world/hi",
}},
{"/hello", target.Redirect{Pre: true, Abs: true}, mss{
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{
"/": "",
"/hello": "/",
"/hello/hi": "/",
}},
{"/hello", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{
{"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{
"/": "",
"/hello": "/world",
"/hello/hi": "/world",
@ -181,10 +182,10 @@ func TestRouter_AddRoute(t *testing.T) {
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
dst := i.dst
dst.Host = "127.0.0.1"
dst.Port = 8080
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRoute("example.com", i.path, dst)
r.AddRoute(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
@ -217,10 +218,11 @@ func TestRouter_AddRedirect(t *testing.T) {
for _, i := range redirectTests {
r := New(nil)
dst := i.dst
dst.Host = "example.com"
dst.Dst = path.Join("example.com", dst.Dst)
dst.Code = http.StatusFound
dst.Src = path.Join("www.example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRedirect("www.example.com", i.path, dst)
r.AddRedirect(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v}
if v == "" {
@ -266,10 +268,10 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
dst := i.dst
dst.Host = "127.0.0.1"
dst.Port = 8080
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRoute("example.com", i.path, dst)
r.AddRoute(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)

View File

@ -1,115 +0,0 @@
package servers
import (
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
"time"
)
// NewApiServer creates and runs a http server containing all the API
// endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects
func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New()
// Endpoint for compile action
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if !hasPerms(conf.Signer, req, "violet:compile") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Trigger the compile action
compileTarget.Compile()
rw.WriteHeader(http.StatusAccepted)
})
// Endpoint for domains
r.PUT("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Signer, req, "violet:domains") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// add domain with active state
q := req.URL.Query()
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
conf.Domains.Compile()
})
r.DELETE("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Signer, req, "violet:domains") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// add domain with active state
q := req.URL.Query()
conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1")
conf.Domains.Compile()
})
// Endpoint for routes
r.POST("/route", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
})
// Endpoint for acme-challenge
r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Signer, req, "violet:acme-challenge") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
domain := params.ByName("domain")
if !conf.Domains.IsValid(domain) {
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
return
}
conf.Acme.Put(domain, params.ByName("key"), params.ByName("value"))
rw.WriteHeader(http.StatusAccepted)
})
r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Signer, req, "violet:acme-challenge") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
domain := params.ByName("domain")
if !conf.Domains.IsValid(domain) {
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
return
}
conf.Acme.Delete(domain, params.ByName("key"))
rw.WriteHeader(http.StatusAccepted)
})
// Create and run http server
return &http.Server{
Addr: conf.ApiListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
}
}
func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool {
// Get bearer token
bearer := utils.GetBearer(req)
if bearer == "" {
return false
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
return false
}
// Token must have perm
return b.Claims.Perms.Has(perm)
}

102
servers/api/api.go Normal file
View File

@ -0,0 +1,102 @@
package api
import (
"encoding/json"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/claims"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
"time"
)
// NewApiServer creates and runs a http server containing all the API
// endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New()
// Endpoint for compile action
r.POST("/compile", checkAuthWithPerm(conf.Signer, "violet:compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, b AuthClaims) {
// Trigger the compile action
compileTarget.Compile()
rw.WriteHeader(http.StatusAccepted)
}))
// Endpoint for domains
domainFunc := domainManage(conf.Signer, conf.Domains)
r.PUT("/domain/:domain", domainFunc)
r.DELETE("/domain/:domain", domainFunc)
// Endpoint code for target routes/redirects
targetApis := SetupTargetApis(conf.Signer, conf.Router)
// Endpoint for routes
r.POST("/route", targetApis.CreateRoute)
r.DELETE("/route", targetApis.DeleteRoute)
// Endpoint for redirects
r.POST("/redirect", targetApis.CreateRedirect)
r.DELETE("/redirect", targetApis.DeleteRedirect)
// Endpoint for acme-challenge
acmeChallengeFunc := acmeChallengeManage(conf.Signer, conf.Domains, conf.Acme)
r.PUT("/acme-challenge/:domain/:key/:value", acmeChallengeFunc)
r.DELETE("/acme-challenge/:domain/:key", acmeChallengeFunc)
// Create and run http server
return &http.Server{
Addr: conf.ApiListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
}
}
// apiError outputs a generic JSON error message
func apiError(rw http.ResponseWriter, code int, m string) {
rw.WriteHeader(code)
_ = json.NewEncoder(rw).Encode(map[string]string{
"error": m,
})
}
func domainManage(verify mjwt.Verifier, domains utils.DomainProvider) httprouter.Handle {
return checkAuthWithPerm(verify, "violet:domains", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
// add domain with active state
domains.Put(params.ByName("domain"), req.Method == http.MethodPut)
domains.Compile()
})
}
func acmeChallengeManage(verify mjwt.Verifier, domains utils.DomainProvider, acme utils.AcmeChallengeProvider) httprouter.Handle {
return checkAuthWithPerm(verify, "violet:acme-challenge", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
domain := params.ByName("domain")
if !domains.IsValid(domain) {
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
return
}
if req.Method == http.MethodPut {
acme.Put(domain, params.ByName("key"), params.ByName("value"))
} else {
acme.Delete(domain, params.ByName("key"))
}
rw.WriteHeader(http.StatusAccepted)
})
}
// validateDomainOwnershipClaims validates if the claims contain the
// `owns=<fqdn>` field with the matching top level domain
func validateDomainOwnershipClaims(a string, perms *claims.PermStorage) bool {
if fqdn, ok := utils.GetTopFqdn(a); ok {
if perms.Has("owns=" + fqdn) {
return true
}
}
return false
}

View File

@ -1,59 +1,22 @@
package servers
package api
import (
"crypto/rand"
"crypto/rsa"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/MrMelon54/violet/utils/fake"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
var snakeOilProv = genSnakeOilProv()
type fakeDomains struct{}
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
func (f *fakeDomains) Put(string, bool) {}
func (f *fakeDomains) Delete(string) {}
func (f *fakeDomains) Compile() {}
func genSnakeOilProv() mjwt.Signer {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
return mjwt.NewMJwtSigner("violet.test", key)
}
func genSnakeOilKey(perm string) string {
p := claims.NewPermStorage()
p.Set(perm)
val, err := snakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p})
if err != nil {
panic(err)
}
return val
}
type fakeCompilable struct{ done bool }
func (f *fakeCompilable) Compile() { f.done = true }
var _ utils.Compilable = &fakeCompilable{}
func TestNewApiServer_Compile(t *testing.T) {
apiConf := &Conf{
Domains: &fakeDomains{},
apiConf := &conf.Conf{
Domains: &fake.Domains{},
Acme: utils.NewAcmeChallenge(),
Signer: snakeOilProv,
Signer: fake.SnakeOilProv,
}
f := &fakeCompilable{}
f := &fake.Compilable{}
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
@ -63,25 +26,25 @@ func TestNewApiServer_Compile(t *testing.T) {
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusForbidden, res.StatusCode)
assert.False(t, f.done)
assert.False(t, f.Done)
req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile"))
req.Header.Set("Authorization", "Bearer "+fake.GenSnakeOilKey("violet:compile"))
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.True(t, f.done)
assert.True(t, f.Done)
}
func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
apiConf := &Conf{
Domains: &fakeDomains{},
apiConf := &conf.Conf{
Domains: &fake.Domains{},
Acme: utils.NewAcmeChallenge(),
Signer: snakeOilProv,
Signer: fake.SnakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{})
acmeKey := genSnakeOilKey("violet:acme-challenge")
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
// Valid domain
req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil)
@ -119,13 +82,13 @@ func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
}
func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
apiConf := &Conf{
Domains: &fakeDomains{},
apiConf := &conf.Conf{
Domains: &fake.Domains{},
Acme: utils.NewAcmeChallenge(),
Signer: snakeOilProv,
Signer: fake.SnakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{})
acmeKey := genSnakeOilKey("violet:acme-challenge")
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
// Valid domain
req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil)

49
servers/api/auth.go Normal file
View File

@ -0,0 +1,49 @@
package api
import (
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
)
type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims]
type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims)
// checkAuth validates the bearer token against a mjwt.Verifier and returns an
// error message or continues to the next handler
func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle {
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
// Get bearer token
bearer := utils.GetBearer(req)
if bearer == "" {
apiError(rw, http.StatusForbidden, "Missing bearer token")
return
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
apiError(rw, http.StatusForbidden, "Invalid token")
return
}
cb(rw, req, params, AuthClaims(b))
}
}
// checkAuthWithPerm validates the bearer token and checks if it contains a
// required permission and returns an error message or continues to the next
// handler
func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle {
return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
// check perms
if !b.Claims.Perms.Has(perm) {
apiError(rw, http.StatusForbidden, "No permission")
return
}
cb(rw, req, params, b)
})
}

View File

@ -0,0 +1,27 @@
package api
import (
"github.com/MrMelon54/violet/target"
)
type sourceJson struct {
Src string `json:"src"`
}
func (s sourceJson) GetSource() string { return s.Src }
type routeSource target.Route
func (r routeSource) GetSource() string { return r.Src }
type redirectSource target.Redirect
func (r redirectSource) GetSource() string { return r.Src }
var (
_ sourceGetter = sourceJson{}
_ sourceGetter = routeSource{}
_ sourceGetter = redirectSource{}
)
type sourceGetter interface{ GetSource() string }

88
servers/api/target.go Normal file
View File

@ -0,0 +1,88 @@
package api
import (
"encoding/json"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"log"
"net/http"
"strings"
)
type TargetApis struct {
CreateRoute httprouter.Handle
DeleteRoute httprouter.Handle
CreateRedirect httprouter.Handle
DeleteRedirect httprouter.Handle
}
func SetupTargetApis(verify mjwt.Verifier, manager *router.Manager) *TargetApis {
r := &TargetApis{
CreateRoute: parseJsonAndCheckOwnership[routeSource](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t routeSource) {
err := manager.InsertRoute(target.Route(t))
if err != nil {
log.Printf("[Violet] Failed to insert route into database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to insert route into database")
return
}
manager.Compile()
}),
DeleteRoute: parseJsonAndCheckOwnership[sourceJson](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
err := manager.DeleteRoute(t.Src)
if err != nil {
log.Printf("[Violet] Failed to delete route from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to delete route from database")
return
}
manager.Compile()
}),
CreateRedirect: parseJsonAndCheckOwnership[redirectSource](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t redirectSource) {
err := manager.InsertRedirect(target.Redirect(t))
if err != nil {
log.Printf("[Violet] Failed to insert redirect into database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to insert redirect into database")
return
}
manager.Compile()
}),
DeleteRedirect: parseJsonAndCheckOwnership[sourceJson](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
err := manager.DeleteRedirect(t.Src)
if err != nil {
log.Printf("[Violet] Failed to delete redirect from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to delete redirect from database")
return
}
manager.Compile()
}),
}
return r
}
type AuthWithJsonCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t T)
func parseJsonAndCheckOwnership[T sourceGetter](verify mjwt.Verifier, t string, cb AuthWithJsonCallback[T]) httprouter.Handle {
return checkAuthWithPerm(verify, "violet:"+t, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
var j T
if json.NewDecoder(req.Body).Decode(&j) != nil {
apiError(rw, http.StatusBadRequest, "Invalid request body")
return
}
// check token owns this domain
host, _ := utils.SplitHostPath(j.GetSource())
if strings.IndexByte(host, ':') != -1 {
apiError(rw, http.StatusBadRequest, "Invalid route source")
return
}
if !validateDomainOwnershipClaims(host, b.Claims.Perms) {
apiError(rw, http.StatusBadRequest, "Token cannot modify the specified domain")
return
}
cb(rw, req, params, b, j)
})
}

View File

@ -1,12 +1,12 @@
package servers
package conf
import (
"crypto/tls"
"database/sql"
"github.com/MrMelon54/mjwt"
errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/utils"
)
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
@ -16,29 +16,11 @@ type Conf struct {
HttpsListen string // https server listen address
RateLimit uint64 // rate limit per minute
DB *sql.DB
Domains DomainProvider
Acme AcmeChallengeProvider
Certs CertProvider
Domains utils.DomainProvider
Acme utils.AcmeChallengeProvider
Certs utils.CertProvider
Favicons *favicons.Favicons
Signer mjwt.Verifier
ErrorPages *errorPages.ErrorPages
Router *router.Manager
}
type DomainProvider interface {
IsValid(host string) bool
Put(domain string, active bool)
Delete(domain string)
Compile()
}
type AcmeChallengeProvider interface {
Get(domain, key string) string
Put(domain, key, value string)
Delete(domain, key string)
}
type CertProvider interface {
GetCertForDomain(domain string) *tls.Certificate
Compile()
}

View File

@ -2,6 +2,7 @@ package servers
import (
"fmt"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
@ -14,7 +15,7 @@ import (
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Let's Encrypt HTTP verification.
func NewHttpServer(conf *Conf) *http.Server {
func NewHttpServer(conf *conf.Conf) *http.Server {
r := httprouter.New()
var secureExtend string
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)

View File

@ -2,7 +2,9 @@ package servers
import (
"bytes"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/MrMelon54/violet/utils/fake"
"github.com/stretchr/testify/assert"
"io"
"net/http"
@ -11,10 +13,10 @@ import (
)
func TestNewHttpServer_AcmeChallenge(t *testing.T) {
httpConf := &Conf{
Domains: &fakeDomains{},
httpConf := &conf.Conf{
Domains: &fake.Domains{},
Acme: utils.NewAcmeChallenge(),
Signer: snakeOilProv,
Signer: fake.SnakeOilProv,
}
srv := NewHttpServer(httpConf)
httpConf.Acme.Put("example.com", "456", "456def")

View File

@ -4,6 +4,7 @@ import (
"crypto/tls"
"fmt"
"github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
@ -16,7 +17,7 @@ import (
// NewHttpsServer creates and runs a http server containing the public https
// endpoints for the reverse proxy.
func NewHttpsServer(conf *Conf) *http.Server {
func NewHttpsServer(conf *conf.Conf) *http.Server {
return &http.Server{
Addr: conf.HttpsListen,
Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),

View File

@ -5,6 +5,8 @@ import (
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils/fake"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"net/http"
@ -26,11 +28,11 @@ func TestNewHttpsServer_RateLimit(t *testing.T) {
assert.NoError(t, err)
ft := &fakeTransport{}
httpsConf := &Conf{
httpsConf := &conf.Conf{
RateLimit: 5,
Domains: &fakeDomains{},
Domains: &fake.Domains{},
Certs: certs.New(nil, nil, true),
Signer: snakeOilProv,
Signer: fake.SnakeOilProv,
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
}
srv := NewHttpsServer(httpsConf)

41
target/flags.go Normal file
View File

@ -0,0 +1,41 @@
package target
type Flags uint64
const (
FlagPre Flags = 1 << iota
FlagAbs
FlagCors
FlagSecureMode
FlagForwardHost
FlagForwardAddr
FlagIgnoreCert
)
var (
routeFlagMask = FlagPre | FlagAbs | FlagCors | FlagSecureMode | FlagForwardHost | FlagForwardAddr | FlagIgnoreCert
redirectFlagMask = FlagPre | FlagAbs
)
// HasFlag returns true if the bits contain the requested flag
func (f Flags) HasFlag(flag Flags) bool {
// 0110 & 0100 == 0100 (value != 0 thus true)
// 0011 & 0100 == 0000 (value == 0 thus false)
return f&flag != 0
}
// NormaliseRouteFlags returns only the bits used for routes
func (f Flags) NormaliseRouteFlags() Flags {
// removes bits outside the mask
// 0110 & 0111 == 0110
// 1010 & 0111 == 0010 (values are different)
return f & routeFlagMask
}
// NormaliseRedirectFlags returns only the bits used for redirects
func (f Flags) NormaliseRedirectFlags() Flags {
// removes bits outside the mask
// 0110 & 0111 == 0110
// 1010 & 0111 == 0010 (values are different)
return f & redirectFlagMask
}

View File

@ -12,20 +12,14 @@ import (
// Redirect is a target used by the router to manage redirecting the request
// using the specified configuration.
type Redirect struct {
Pre bool // if the path has had a prefix removed
Host string // target host
Port int // target port
Path string // target path (possibly a prefix or absolute)
Abs bool // if the path is a prefix or absolute
Code int // status code used to redirect
Src string `json:"src"` // request source
Dst string `json:"dst"` // redirect destination
Flags Flags `json:"flags"` // extra flags
Code int `json:"code"` // status code used to redirect
}
// FullHost outputs a host:port combo or just the host if the port is 0.
func (r Redirect) FullHost() string {
if r.Port == 0 {
return r.Host
}
return fmt.Sprintf("%s:%d", r.Host, r.Port)
func (r Route) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}
// ServeHTTP responds with the redirect to the response writer provided.
@ -36,10 +30,12 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
code = http.StatusFound
}
// split the host and path
host, p := utils.SplitHostPath(r.Dst)
// if not Abs then join with the ending of the current path
p := r.Path
if !r.Abs {
p = path.Join(r.Path, req.URL.Path)
if !r.Flags.HasFlag(FlagAbs) {
p = path.Join(p, req.URL.Path)
// replace the trailing slash that path.Join() strips off
if strings.HasSuffix(req.URL.Path, "/") {
@ -55,7 +51,7 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// create a new URL
u := &url.URL{
Scheme: req.URL.Scheme,
Host: r.FullHost(),
Host: host,
Path: p,
}

View File

@ -7,18 +7,13 @@ import (
"testing"
)
func TestRedirect_FullHost(t *testing.T) {
assert.Equal(t, "localhost", Redirect{Host: "localhost"}.FullHost())
assert.Equal(t, "localhost:22", Redirect{Host: "localhost", Port: 22}.FullHost())
}
func TestRedirect_ServeHTTP(t *testing.T) {
a := []struct {
Redirect
target string
}{
{Redirect{Host: "example.com", Path: "/bye", Abs: true, Code: http.StatusFound}, "https://example.com/bye"},
{Redirect{Host: "example.com", Path: "/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"},
{Redirect{Dst: "example.com/bye", Flags: FlagAbs, Code: http.StatusFound}, "https://example.com/bye"},
{Redirect{Dst: "example.com/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"},
}
for _, i := range a {
res := httptest.NewRecorder()

View File

@ -36,18 +36,11 @@ var serveApiCors = cors.New(cors.Options{
// Route is a target used by the router to manage forwarding traffic to an
// internal server using the specified configuration.
type Route struct {
Pre bool // if the path has had a prefix removed
Host string // target host
Port int // target port
Path string // target path (possibly a prefix or absolute)
Abs bool // if the path is a prefix or absolute
Cors bool // add CORS headers
SecureMode bool // use HTTPS internally
ForwardHost bool // forward host header internally
ForwardAddr bool // forward remote address
IgnoreCert bool // ignore self-cert
Headers http.Header // extra headers
Proxy *proxy.HybridTransport // reverse proxy handler
Src string `json:"src"` // request source
Dst string `json:"dst"` // proxy destination
Flags Flags `json:"flags"` // extra flags
Headers http.Header `json:"-"` // extra headers
Proxy *proxy.HybridTransport `json:"-"` // reverse proxy handler
}
// UpdateHeaders takes an existing set of headers and overwrites them with the
@ -58,18 +51,10 @@ func (r Route) UpdateHeaders(header http.Header) {
}
}
// FullHost outputs a host:port combo or just the host if the port is 0.
func (r Route) FullHost() string {
if r.Port == 0 {
return r.Host
}
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
// ServeHTTP responds with the data proxied from the internal server to the
// response writer provided.
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if r.Cors {
if r.HasFlag(FlagCors) {
// wraps with CORS handler
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
} else {
@ -82,21 +67,16 @@ func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// set the scheme and port using defaults if the port is 0
scheme := "http"
if r.SecureMode {
if r.HasFlag(FlagSecureMode) {
scheme = "https"
if r.Port == 0 {
r.Port = 443
}
} else {
if r.Port == 0 {
r.Port = 80
}
}
// split the host and path
host, p := utils.SplitHostPath(r.Dst)
// if not Abs then join with the ending of the current path
p := r.Path
if !r.Abs {
p = path.Join(r.Path, req.URL.Path)
if !r.HasFlag(FlagAbs) {
p = path.Join(p, req.URL.Path)
// replace the trailing slash that path.Join() strips off
if strings.HasSuffix(req.URL.Path, "/") {
@ -112,7 +92,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// create a new URL
u := &url.URL{
Scheme: scheme,
Host: r.FullHost(),
Host: host,
Path: p,
RawQuery: req.URL.RawQuery,
}
@ -150,10 +130,10 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// if forward host is enabled then send the host
if r.ForwardHost {
if r.HasFlag(FlagForwardHost) {
req2.Host = req.Host
}
if r.ForwardAddr {
if r.HasFlag(FlagForwardAddr) {
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
}
@ -162,7 +142,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// serve request with reverse proxy
var resp *http.Response
if r.IgnoreCert {
if r.HasFlag(FlagIgnoreCert) {
resp, err = r.Proxy.InsecureRoundTrip(req2)
} else {
resp, err = r.Proxy.SecureRoundTrip(req2)

View File

@ -25,9 +25,9 @@ func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK}, nil
}
func TestRoute_FullHost(t *testing.T) {
assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost())
assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost())
func TestRoute_HasFlag(t *testing.T) {
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
}
func TestRoute_ServeHTTP(t *testing.T) {
@ -35,12 +35,12 @@ func TestRoute_ServeHTTP(t *testing.T) {
Route
target string
}{
{Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"},
{Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"},
{Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"},
{Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"},
{Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"},
{Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"},
{Route{Dst: "localhost:1234/bye", Flags: FlagAbs}, "http://localhost:1234/bye"},
{Route{Dst: "1.2.3.4/bye"}, "http://1.2.3.4/bye/hello/world"},
{Route{Dst: "2.2.2.2/world", Flags: FlagAbs | FlagSecureMode}, "https://2.2.2.2/world"},
{Route{Dst: "api.example.com/world", Flags: FlagAbs | FlagSecureMode | FlagForwardHost}, "https://api.example.com/world"},
{Route{Dst: "api.example.org/world", Flags: FlagAbs | FlagSecureMode | FlagForwardAddr}, "https://api.example.org/world"},
{Route{Dst: "3.3.3.3/headers", Flags: FlagAbs, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3/headers"},
}
for _, i := range a {
pt := &proxyTester{}
@ -51,10 +51,10 @@ func TestRoute_ServeHTTP(t *testing.T) {
assert.True(t, pt.got)
assert.Equal(t, i.target, pt.req.URL.String())
if i.ForwardAddr {
if i.HasFlag(FlagForwardAddr) {
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
}
if i.ForwardHost {
if i.HasFlag(FlagForwardHost) {
assert.Equal(t, req.Host, pt.req.Host)
}
if i.Headers != nil {
@ -68,7 +68,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil)
req.Header.Set("Origin", "https://test.example.com")
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()}
i.ServeHTTP(res, req)
assert.True(t, pt.got)
@ -86,7 +86,7 @@ func TestRoute_ServeHTTP_Body(t *testing.T) {
buf := bytes.NewBuffer([]byte{0x54})
req := httptest.NewRequest(http.MethodPost, "https://www.example.com/test", buf)
req.Header.Set("Origin", "https://test.example.com")
i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()}
i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()}
i.ServeHTTP(res, req)
assert.True(t, pt.got)

View File

@ -83,3 +83,38 @@ func GetTopFqdn(domain string) (string, bool) {
}
return domain[n+1:], true
}
// SplitHostPath extracts the host/path from the input
func SplitHostPath(a string) (host, path string) {
// check if source has path
n := strings.IndexByte(a, '/')
if n == -1 {
// set host then path to /
host = a
path = "/"
} else {
// set host then custom path
host = a[:n]
path = a[n:] // this required to keep / at the start of the path
}
return
}
// SplitHostPathQuery extracts the host/path?query from the input
func SplitHostPathQuery(a string) (host, path, query string) {
host, path = SplitHostPath(a)
if path == "/" {
n := strings.IndexByte(host, '?')
if n != -1 {
query = host[n+1:]
host = host[:n]
}
return
}
n := strings.IndexByte(path, '?')
if n != -1 {
query = path[n+1:]
path = path[:n] // reassign happens after
}
return
}

View File

@ -60,3 +60,40 @@ func TestGetTopFqdn(t *testing.T) {
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
}
func TestSplitHostPath(t *testing.T) {
h, p := SplitHostPath("example.com/hello/world")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/hello/world", p)
h, p = SplitHostPath("example.com")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/", p)
}
func TestSplitHostPathQuery(t *testing.T) {
h, p, q := SplitHostPathQuery("example.com/hello/world")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/hello/world", p)
assert.Equal(t, "", q)
h, p, q = SplitHostPathQuery("example.com")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/", p)
assert.Equal(t, "", q)
h, p, q = SplitHostPathQuery("example.com/hello/world?a=b")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/hello/world", p)
assert.Equal(t, "a=b", q)
h, p, q = SplitHostPathQuery("example.com?a=b")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/", p)
assert.Equal(t, "a=b", q)
h, p, q = SplitHostPathQuery("example.com/?a=b")
assert.Equal(t, "example.com", h)
assert.Equal(t, "/", p)
assert.Equal(t, "a=b", q)
}

View File

@ -0,0 +1,11 @@
package fake
import "github.com/MrMelon54/violet/utils"
// Compilable implements utils.Compilable and stores if the Compile function
// is called.
type Compilable struct{ Done bool }
func (f *Compilable) Compile() { f.Done = true }
var _ utils.Compilable = &Compilable{}

View File

@ -0,0 +1,13 @@
package fake
import "github.com/MrMelon54/violet/utils"
// Domains implements DomainProvider and makes sure `example.com` is valid
type Domains struct{}
func (f *Domains) IsValid(host string) bool { return host == "example.com" }
func (f *Domains) Put(string, bool) {}
func (f *Domains) Delete(string) {}
func (f *Domains) Compile() {}
var _ utils.DomainProvider = &Domains{}

2
utils/fake/fake.go Normal file
View File

@ -0,0 +1,2 @@
// Package fake contains fake structs used during tests
package fake

30
utils/fake/mjwt.go Normal file
View File

@ -0,0 +1,30 @@
package fake
import (
"crypto/rand"
"crypto/rsa"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims"
"time"
)
var SnakeOilProv = GenSnakeOilProv()
func GenSnakeOilProv() mjwt.Signer {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
return mjwt.NewMJwtSigner("violet.test", key)
}
func GenSnakeOilKey(perm string) string {
p := claims.NewPermStorage()
p.Set(perm)
val, err := SnakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p})
if err != nil {
panic(err)
}
return val
}

21
utils/interfaces.go Normal file
View File

@ -0,0 +1,21 @@
package utils
import "crypto/tls"
type DomainProvider interface {
IsValid(host string) bool
Put(domain string, active bool)
Delete(domain string)
Compile()
}
type AcmeChallengeProvider interface {
Get(domain, key string) string
Put(domain, key, value string)
Delete(domain, key string)
}
type CertProvider interface {
GetCertForDomain(domain string) *tls.Certificate
Compile()
}