From 949dcd298a07a50a520987ffa1f0f66162de4d86 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Wed, 12 Jul 2023 16:55:09 +0100 Subject: [PATCH] Write route/redirect APIs and rearrage some other code to make it possible --- .idea/dataSources.xml | 4 +- cmd/violet/serve.go | 40 ++++---- cmd/violet/setup.go | 16 +-- favicons/create-table-favicons.sql | 7 ++ favicons/favicons.go | 6 +- router/create-table-redirects.sql | 10 -- router/create-table-routes.sql | 14 --- router/create-tables.sql | 18 ++++ router/manager.go | 150 ++++++++--------------------- router/manager_test.go | 3 +- router/query-table-redirects.sql | 7 -- router/query-table-routes.sql | 11 --- router/router.go | 14 ++- router/router_test.go | 74 +++++++------- servers/api.go | 115 ---------------------- servers/api/api.go | 102 ++++++++++++++++++++ servers/{ => api}/api_test.go | 73 ++++---------- servers/api/auth.go | 49 ++++++++++ servers/api/target-types.go | 27 ++++++ servers/api/target.go | 88 +++++++++++++++++ servers/{ => conf}/conf.go | 28 +----- servers/http.go | 3 +- servers/http_test.go | 8 +- servers/https.go | 3 +- servers/https_test.go | 8 +- target/flags.go | 41 ++++++++ target/redirect.go | 28 +++--- target/redirect_test.go | 9 +- target/route.go | 52 +++------- target/route_test.go | 26 ++--- utils/domain-utils.go | 35 +++++++ utils/domain-utils_test.go | 37 +++++++ utils/fake/fake-compilable.go | 11 +++ utils/fake/fake-domains.go | 13 +++ utils/fake/fake.go | 2 + utils/fake/mjwt.go | 30 ++++++ utils/interfaces.go | 21 ++++ 37 files changed, 683 insertions(+), 500 deletions(-) create mode 100644 favicons/create-table-favicons.sql delete mode 100644 router/create-table-redirects.sql delete mode 100644 router/create-table-routes.sql create mode 100644 router/create-tables.sql delete mode 100644 router/query-table-redirects.sql delete mode 100644 router/query-table-routes.sql delete mode 100644 servers/api.go create mode 100644 servers/api/api.go rename servers/{ => api}/api_test.go (68%) create mode 100644 servers/api/auth.go create mode 100644 servers/api/target-types.go create mode 100644 servers/api/target.go rename servers/{ => conf}/conf.go (56%) create mode 100644 target/flags.go create mode 100644 utils/fake/fake-compilable.go create mode 100644 utils/fake/fake-domains.go create mode 100644 utils/fake/fake.go create mode 100644 utils/fake/mjwt.go create mode 100644 utils/interfaces.go diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml index aba4cb2..d3d2ae2 100644 --- a/.idea/dataSources.xml +++ b/.idea/dataSources.xml @@ -1,11 +1,11 @@ - + sqlite.xerial true org.sqlite.JDBC - jdbc:sqlite:__db.sqlite + jdbc:sqlite:identifier.sqlite $ProjectFileDir$ diff --git a/cmd/violet/serve.go b/cmd/violet/serve.go index 17893cc..48e7faa 100644 --- a/cmd/violet/serve.go +++ b/cmd/violet/serve.go @@ -14,6 +14,8 @@ import ( "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/router" "github.com/MrMelon54/violet/servers" + "github.com/MrMelon54/violet/servers/api" + "github.com/MrMelon54/violet/servers/conf" "github.com/MrMelon54/violet/utils" "github.com/google/subcommands" "io/fs" @@ -70,9 +72,9 @@ func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{ return subcommands.ExitSuccess } -func normalLoad(conf startUpConfig, wd string) { +func normalLoad(startUp startUpConfig, wd string) { // the cert and key paths are useless in self-signed mode - if !conf.SelfSigned { + if !startUp.SelfSigned { // create path to cert dir err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm) if err != nil { @@ -87,11 +89,11 @@ func normalLoad(conf startUpConfig, wd string) { // errorPageDir stores an FS interface for accessing the error page directory var errorPageDir fs.FS - if conf.ErrorPagePath != "" { - errorPageDir = os.DirFS(conf.ErrorPagePath) - err := os.MkdirAll(conf.ErrorPagePath, os.ModePerm) + if startUp.ErrorPagePath != "" { + errorPageDir = os.DirFS(startUp.ErrorPagePath) + err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm) if err != nil { - log.Fatalf("[Violet] Failed to create error page path '%s'", conf.ErrorPagePath) + log.Fatalf("[Violet] Failed to create error page path '%s'", startUp.ErrorPagePath) } } @@ -110,20 +112,20 @@ func normalLoad(conf startUpConfig, wd string) { certDir := os.DirFS(filepath.Join(wd, "certs")) keyDir := os.DirFS(filepath.Join(wd, "keys")) - allowedDomains := domains.New(db) // load allowed domains - acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store - allowedCerts := certs.New(certDir, keyDir, conf.SelfSigned) // load certificate manager - hybridTransport := proxy.NewHybridTransport() // load reverse proxy - dynamicFavicons := favicons.New(db, conf.InkscapeCmd) // load dynamic favicon provider - dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider - dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager + allowedDomains := domains.New(db) // load allowed domains + acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store + allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager + hybridTransport := proxy.NewHybridTransport() // load reverse proxy + dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider + dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider + dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager // struct containing config for the http servers - srvConf := &servers.Conf{ - ApiListen: conf.Listen.Api, - HttpListen: conf.Listen.Http, - HttpsListen: conf.Listen.Https, - RateLimit: conf.RateLimit, + srvConf := &conf.Conf{ + ApiListen: startUp.Listen.Api, + HttpListen: startUp.Listen.Http, + HttpsListen: startUp.Listen.Https, + RateLimit: startUp.RateLimit, DB: db, Domains: allowedDomains, Acme: acmeChallenges, @@ -140,7 +142,7 @@ func normalLoad(conf startUpConfig, wd string) { var srvApi, srvHttp, srvHttps *http.Server if srvConf.ApiListen != "" { - srvApi = servers.NewApiServer(srvConf, allCompilables) + srvApi = api.NewApiServer(srvConf, allCompilables) log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr) go utils.RunBackgroundHttp("API", srvApi) } diff --git a/cmd/violet/setup.go b/cmd/violet/setup.go index 887ddda..087991f 100644 --- a/cmd/violet/setup.go +++ b/cmd/violet/setup.go @@ -181,13 +181,15 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) // add with the route manager, no need to compile as this will run when opened // with the serve subcommand routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{})) - routeManager.Add(path.Join(apiUrl.Host, apiUrl.Path), target.Route{ - Pre: true, - Host: answers.ApiListen, - Cors: true, - ForwardHost: true, - ForwardAddr: true, - }, true) + err = routeManager.InsertRoute(target.Route{ + Src: path.Join(apiUrl.Host, apiUrl.Path), + Dst: answers.ApiListen, + Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr, + }) + if err != nil { + fmt.Println("[Violet] Failed to insert api route into database: ", err) + return subcommands.ExitFailure + } } fmt.Println("[Violet] Setup complete") diff --git a/favicons/create-table-favicons.sql b/favicons/create-table-favicons.sql new file mode 100644 index 0000000..edc48ec --- /dev/null +++ b/favicons/create-table-favicons.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS favicons ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + host VARCHAR, + svg VARCHAR, + png VARCHAR, + ico VARCHAR +); diff --git a/favicons/favicons.go b/favicons/favicons.go index bbf4995..0864859 100644 --- a/favicons/favicons.go +++ b/favicons/favicons.go @@ -2,6 +2,7 @@ package favicons import ( "database/sql" + _ "embed" "errors" "fmt" "github.com/MrMelon54/rescheduler" @@ -12,6 +13,9 @@ import ( var ErrFaviconNotFound = errors.New("favicon not found") +//go:embed create-table-favicons.sql +var createTableFavicons string + // Favicons is a dynamic favicon generator which supports overwriting favicons type Favicons struct { db *sql.DB @@ -32,7 +36,7 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons { f.r = rescheduler.NewRescheduler(f.threadCompile) // init favicons table - _, err := f.db.Exec(`create table if not exists favicons (id integer primary key autoincrement, host varchar, svg varchar, png varchar, ico varchar)`) + _, err := f.db.Exec(createTableFavicons) if err != nil { log.Printf("[WARN] Failed to generate 'favicons' table\n") return nil diff --git a/router/create-table-redirects.sql b/router/create-table-redirects.sql deleted file mode 100644 index 2ae16d1..0000000 --- a/router/create-table-redirects.sql +++ /dev/null @@ -1,10 +0,0 @@ -CREATE TABLE IF NOT EXISTS redirects -( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source TEXT, - pre INTEGER, - destination TEXT, - abs INTEGER, - code INTEGER, - active INTEGER DEFAULT 1 -); diff --git a/router/create-table-routes.sql b/router/create-table-routes.sql deleted file mode 100644 index 2b5fc38..0000000 --- a/router/create-table-routes.sql +++ /dev/null @@ -1,14 +0,0 @@ -CREATE TABLE IF NOT EXISTS routes -( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source TEXT, - pre INTEGER, - destination TEXT, - abs INTEGER, - cors INTEGER, - secure_mode INTEGER, - forward_host INTEGER, - forward_addr INTEGER, - ignore_cert INTEGER, - active INTEGER DEFAULT 1 -); diff --git a/router/create-tables.sql b/router/create-tables.sql new file mode 100644 index 0000000..7779e67 --- /dev/null +++ b/router/create-tables.sql @@ -0,0 +1,18 @@ +CREATE TABLE IF NOT EXISTS routes +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT UNIQUE, + destination TEXT, + flags INTEGER DEFAULT 0, + active INTEGER DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS redirects +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT UNIQUE, + destination TEXT, + flags INTEGER DEFAULT 0, + code INTEGER DEFAULT 0, + active INTEGER DEFAULT 1 +); diff --git a/router/manager.go b/router/manager.go index 994d1c4..85208f1 100644 --- a/router/manager.go +++ b/router/manager.go @@ -3,15 +3,11 @@ package router import ( "database/sql" _ "embed" - "fmt" "github.com/MrMelon54/rescheduler" "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/target" - "github.com/MrMelon54/violet/utils" "log" "net/http" - "path" - "strings" "sync" ) @@ -26,14 +22,8 @@ type Manager struct { } var ( - //go:embed create-table-routes.sql - createTableRoutes string - //go:embed create-table-redirects.sql - createTableRedirects string - //go:embed query-table-routes.sql - queryTableRoutes string - //go:embed query-table-redirects.sql - queryTableRedirects string + //go:embed create-tables.sql + createTables string ) // NewManager create a new manager, initialises the routes and redirects tables @@ -48,16 +38,9 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager { m.z = rescheduler.NewRescheduler(m.threadCompile) // init routes table - _, err := m.db.Exec(createTableRoutes) + _, err := m.db.Exec(createTables) if err != nil { - log.Printf("[WARN] Failed to generate 'routes' table\n") - return nil - } - - // init redirects table - _, err = m.db.Exec(createTableRedirects) - if err != nil { - log.Printf("[WARN] Failed to generate 'redirects' table\n") + log.Printf("[WARN] Failed to generate tables\n") return nil } return m @@ -96,7 +79,7 @@ func (m *Manager) internalCompile(router *Router) error { log.Println("[Manager] Updating routes from database") // sql or something? - rows, err := m.db.Query(queryTableRoutes) + rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`) if err != nil { return err } @@ -105,26 +88,19 @@ func (m *Manager) internalCompile(router *Router) error { // loop through rows and scan the options for rows.Next() { var ( - pre, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert bool - src, dst string + src, dst string + flags target.Flags ) - err := rows.Scan(&src, &pre, &dst, &abs, &cors, &secure_mode, &forward_host, &forward_addr, &ignore_cert) + err := rows.Scan(&src, &dst, &flags) if err != nil { return err } - err = addRoute(router, src, dst, target.Route{ - Pre: pre, - Abs: abs, - Cors: cors, - SecureMode: secure_mode, - ForwardHost: forward_host, - ForwardAddr: forward_addr, - IgnoreCert: ignore_cert, + router.AddRoute(target.Route{ + Src: src, + Dst: dst, + Flags: flags.NormaliseRouteFlags(), }) - if err != nil { - return err - } } // check for errors @@ -133,7 +109,7 @@ func (m *Manager) internalCompile(router *Router) error { } // sql or something? - rows, err = m.db.Query(queryTableRedirects) + rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`) if err != nil { return err } @@ -142,99 +118,51 @@ func (m *Manager) internalCompile(router *Router) error { // loop through rows and scan the options for rows.Next() { var ( - pre, abs bool - code int src, dst string + flags target.Flags + code int ) - err := rows.Scan(&src, &pre, &dst, &abs, &code) + err := rows.Scan(&src, &dst, &flags, &code) if err != nil { return err } - err = addRedirect(router, src, dst, target.Redirect{ - Pre: pre, - Abs: abs, - Code: code, + router.AddRedirect(target.Redirect{ + Src: src, + Dst: dst, + Flags: flags.NormaliseRedirectFlags(), + Code: code, }) - if err != nil { - return err - } } // check for errors return rows.Err() } -func (m *Manager) Add(source string, route target.Route, active bool) { +func (m *Manager) InsertRoute(route target.Route) error { m.s.Lock() defer m.s.Unlock() - _, err := m.db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, source, route.Pre, path.Join(route.Host, route.Path), route.Abs, route.Cors, route.SecureMode, route.ForwardHost, route.ForwardAddr, route.IgnoreCert, active) - if err != nil { - log.Printf("[Violet] Database error: %s\n", err) - } + _, err := m.db.Exec(`INSERT INTO routes (source, destination, flags) VALUES (?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, active = 1`, route.Src, route.Dst, route.Flags) + return err } -// addRoute is an alias to parse the src and dst then add the route -func addRoute(router *Router, src string, dst string, t target.Route) error { - srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst) - if err != nil { - return err - } - - // update target route values and add route - t.Host = dstHost - t.Port = dstPort - t.Path = dstPath - router.AddRoute(srcHost, srcPath, t) - return nil +func (m *Manager) DeleteRoute(source string) error { + m.s.Lock() + defer m.s.Unlock() + _, err := m.db.Exec(`UPDATE routes SET active = 0 WHERE source = ?`, source) + return err } -// addRedirect is an alias to parse the src and dst then add the redirect -func addRedirect(router *Router, src string, dst string, t target.Redirect) error { - srcHost, srcPath, dstHost, dstPort, dstPath, err := parseSrcDstHost(src, dst) - if err != nil { - return err - } - - t.Host = dstHost - t.Port = dstPort - t.Path = dstPath - router.AddRedirect(srcHost, srcPath, t) - return nil +func (m *Manager) InsertRedirect(redirect target.Redirect) error { + m.s.Lock() + defer m.s.Unlock() + _, err := m.db.Exec(`INSERT INTO redirects (source, destination, flags, code) VALUES (?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, code = excluded.code, active = 1`, redirect.Src, redirect.Dst, redirect.Flags, redirect.Code) + return err } -// parseSrcDstHost extracts the host/path and host:port/path from the src and dst values -func parseSrcDstHost(src string, dst string) (string, string, string, int, string, error) { - // check if source has path - var srcHost, srcPath string - nSrc := strings.IndexByte(src, '/') - if nSrc == -1 { - // set host then path to / - srcHost = src - srcPath = "/" - } else { - // set host then custom path - srcHost = src[:nSrc] - srcPath = src[nSrc:] - } - - // check if destination has path - var dstPath string - nDst := strings.IndexByte(dst, '/') - if nDst == -1 { - // set path to / - dstPath = "/" - } else { - // set custom path then trim dst string to the host - dstPath = dst[nDst:] - dst = dst[:nDst] - } - - // try to split the destination host into domain + port - dstHost, dstPort, ok := utils.SplitDomainPort(dst, 0) - if !ok { - return "", "", "", 0, "", fmt.Errorf("failed to split destination '%s' into host + port", dst) - } - - return srcHost, srcPath, dstHost, dstPort, dstPath, nil +func (m *Manager) DeleteRedirect(source string) error { + m.s.Lock() + defer m.s.Unlock() + _, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source) + return err } diff --git a/router/manager_test.go b/router/manager_test.go index ccb8323..759a936 100644 --- a/router/manager_test.go +++ b/router/manager_test.go @@ -3,6 +3,7 @@ package router import ( "database/sql" "github.com/MrMelon54/violet/proxy" + "github.com/MrMelon54/violet/target" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "net/http" @@ -37,7 +38,7 @@ func TestNewManager(t *testing.T) { assert.Equal(t, http.StatusTeapot, res.StatusCode) assert.Nil(t, ft.req) - _, err = db.Exec(`INSERT INTO routes (source, pre, destination, abs, cors, secure_mode, forward_host, forward_addr, ignore_cert, active) VALUES (?,?,?,?,?,?,?,?,?,?)`, "*.example.com", 0, "127.0.0.1:8080", 1, 0, 0, 1, 1, 0, 1) + _, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr) assert.NoError(t, err) assert.NoError(t, m.internalCompile(m.r)) diff --git a/router/query-table-redirects.sql b/router/query-table-redirects.sql deleted file mode 100644 index cb90280..0000000 --- a/router/query-table-redirects.sql +++ /dev/null @@ -1,7 +0,0 @@ -select source, - pre, - destination, - abs, - code -from redirects -where active = true diff --git a/router/query-table-routes.sql b/router/query-table-routes.sql deleted file mode 100644 index 107cd5c..0000000 --- a/router/query-table-routes.sql +++ /dev/null @@ -1,11 +0,0 @@ -select source, - pre, - destination, - abs, - cors, - secure_mode, - forward_host, - forward_addr, - ignore_cert -from routes -where active = true diff --git a/router/router.go b/router/router.go index 38227ec..0f90deb 100644 --- a/router/router.go +++ b/router/router.go @@ -46,16 +46,14 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] { return h } -func (r *Router) AddService(host string, t target.Route) { - r.AddRoute(host, "/", t) -} - -func (r *Router) AddRoute(host string, path string, t target.Route) { +func (r *Router) AddRoute(t target.Route) { t.Proxy = r.proxy + host, path := utils.SplitHostPath(t.Src) r.hostRoute(host).PutString(path, t) } -func (r *Router) AddRedirect(host, path string, t target.Redirect) { +func (r *Router) AddRedirect(t target.Redirect) { + host, path := utils.SplitHostPath(t.Src) r.hostRedirect(host).PutString(path, t) } @@ -95,7 +93,7 @@ func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host if h != nil { pairs := h.GetAllKeyValues([]byte(req.URL.Path)) for i := len(pairs) - 1; i >= 0; i-- { - if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path { + if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path { req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key) pairs[i].Value.ServeHTTP(rw, req) return true @@ -110,7 +108,7 @@ func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, ho if h != nil { pairs := h.GetAllKeyValues([]byte(req.URL.Path)) for i := len(pairs) - 1; i >= 0; i-- { - if pairs[i].Value.Pre || pairs[i].Key == req.URL.Path { + if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path { req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key) pairs[i].Value.ServeHTTP(rw, req) return true diff --git a/router/router_test.go b/router/router_test.go index 5090c78..097c1b6 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "path" "testing" ) @@ -29,31 +30,31 @@ var ( "/": "/", "/hello": "", }}, - {"/", target.Route{Path: "/world"}, mss{ + {"/", target.Route{Dst: "world"}, mss{ "/": "/world", "/hello": "", }}, - {"/", target.Route{Abs: true}, mss{ + {"/", target.Route{Flags: target.FlagAbs}, mss{ "/": "/", "/hello": "", }}, - {"/", target.Route{Abs: true, Path: "world"}, mss{ + {"/", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{ "/": "/world", "/hello": "", }}, - {"/", target.Route{Pre: true}, mss{ + {"/", target.Route{Flags: target.FlagPre}, mss{ "/": "/", "/hello": "/hello", }}, - {"/", target.Route{Pre: true, Path: "world"}, mss{ + {"/", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{ "/": "/world", "/hello": "/world/hello", }}, - {"/", target.Route{Pre: true, Abs: true}, mss{ + {"/", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{ "/": "/", "/hello": "/", }}, - {"/", target.Route{Pre: true, Abs: true, Path: "world"}, mss{ + {"/", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{ "/": "/world", "/hello": "/world", }}, @@ -62,37 +63,37 @@ var ( "/hello": "/", "/hello/hi": "", }}, - {"/hello", target.Route{Path: "world"}, mss{ + {"/hello", target.Route{Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "", }}, - {"/hello", target.Route{Abs: true}, mss{ + {"/hello", target.Route{Flags: target.FlagAbs}, mss{ "/": "", "/hello": "/", "/hello/hi": "", }}, - {"/hello", target.Route{Abs: true, Path: "world"}, mss{ + {"/hello", target.Route{Flags: target.FlagAbs, Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "", }}, - {"/hello", target.Route{Pre: true}, mss{ + {"/hello", target.Route{Flags: target.FlagPre}, mss{ "/": "", "/hello": "/", "/hello/hi": "/hi", }}, - {"/hello", target.Route{Pre: true, Path: "world"}, mss{ + {"/hello", target.Route{Flags: target.FlagPre, Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "/world/hi", }}, - {"/hello", target.Route{Pre: true, Abs: true}, mss{ + {"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs}, mss{ "/": "", "/hello": "/", "/hello/hi": "/", }}, - {"/hello", target.Route{Pre: true, Abs: true, Path: "world"}, mss{ + {"/hello", target.Route{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "/world", @@ -103,31 +104,31 @@ var ( "/": "/", "/hello": "", }}, - {"/", target.Redirect{Path: "world"}, mss{ + {"/", target.Redirect{Dst: "world"}, mss{ "/": "/world", "/hello": "", }}, - {"/", target.Redirect{Abs: true}, mss{ + {"/", target.Redirect{Flags: target.FlagAbs}, mss{ "/": "/", "/hello": "", }}, - {"/", target.Redirect{Abs: true, Path: "world"}, mss{ + {"/", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{ "/": "/world", "/hello": "", }}, - {"/", target.Redirect{Pre: true}, mss{ + {"/", target.Redirect{Flags: target.FlagPre}, mss{ "/": "/", "/hello": "/hello", }}, - {"/", target.Redirect{Pre: true, Path: "world"}, mss{ + {"/", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{ "/": "/world", "/hello": "/world/hello", }}, - {"/", target.Redirect{Pre: true, Abs: true}, mss{ + {"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{ "/": "/", "/hello": "/", }}, - {"/", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{ + {"/", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{ "/": "/world", "/hello": "/world", }}, @@ -136,37 +137,37 @@ var ( "/hello": "/", "/hello/hi": "", }}, - {"/hello", target.Redirect{Path: "world"}, mss{ + {"/hello", target.Redirect{Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "", }}, - {"/hello", target.Redirect{Abs: true}, mss{ + {"/hello", target.Redirect{Flags: target.FlagAbs}, mss{ "/": "", "/hello": "/", "/hello/hi": "", }}, - {"/hello", target.Redirect{Abs: true, Path: "world"}, mss{ + {"/hello", target.Redirect{Flags: target.FlagAbs, Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "", }}, - {"/hello", target.Redirect{Pre: true}, mss{ + {"/hello", target.Redirect{Flags: target.FlagPre}, mss{ "/": "", "/hello": "/", "/hello/hi": "/hi", }}, - {"/hello", target.Redirect{Pre: true, Path: "world"}, mss{ + {"/hello", target.Redirect{Flags: target.FlagPre, Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "/world/hi", }}, - {"/hello", target.Redirect{Pre: true, Abs: true}, mss{ + {"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs}, mss{ "/": "", "/hello": "/", "/hello/hi": "/", }}, - {"/hello", target.Redirect{Pre: true, Abs: true, Path: "world"}, mss{ + {"/hello", target.Redirect{Flags: target.FlagPre | target.FlagAbs, Dst: "world"}, mss{ "/": "", "/hello": "/world", "/hello/hi": "/world", @@ -181,10 +182,10 @@ func TestRouter_AddRoute(t *testing.T) { for _, i := range routeTests { r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure)) dst := i.dst - dst.Host = "127.0.0.1" - dst.Port = 8080 + dst.Dst = path.Join("127.0.0.1:8080", dst.Dst) + dst.Src = path.Join("example.com", i.path) t.Logf("Running tests for %#v\n", dst) - r.AddRoute("example.com", i.path, dst) + r.AddRoute(dst) for k, v := range i.tests { u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k} req, _ := http.NewRequest(http.MethodGet, u1.String(), nil) @@ -217,10 +218,11 @@ func TestRouter_AddRedirect(t *testing.T) { for _, i := range redirectTests { r := New(nil) dst := i.dst - dst.Host = "example.com" + dst.Dst = path.Join("example.com", dst.Dst) dst.Code = http.StatusFound + dst.Src = path.Join("www.example.com", i.path) t.Logf("Running tests for %#v\n", dst) - r.AddRedirect("www.example.com", i.path, dst) + r.AddRedirect(dst) for k, v := range i.tests { u1 := &url.URL{Scheme: "https", Host: "example.com", Path: v} if v == "" { @@ -266,10 +268,10 @@ func TestRouter_AddWildcardRoute(t *testing.T) { for _, i := range routeTests { r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure)) dst := i.dst - dst.Host = "127.0.0.1" - dst.Port = 8080 + dst.Dst = path.Join("127.0.0.1:8080", dst.Dst) + dst.Src = path.Join("example.com", i.path) t.Logf("Running tests for %#v\n", dst) - r.AddRoute("example.com", i.path, dst) + r.AddRoute(dst) for k, v := range i.tests { u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k} req, _ := http.NewRequest(http.MethodGet, u1.String(), nil) diff --git a/servers/api.go b/servers/api.go deleted file mode 100644 index b5bae9e..0000000 --- a/servers/api.go +++ /dev/null @@ -1,115 +0,0 @@ -package servers - -import ( - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/auth" - "github.com/MrMelon54/violet/utils" - "github.com/julienschmidt/httprouter" - "net/http" - "time" -) - -// NewApiServer creates and runs a http server containing all the API -// endpoints for the software -// -// `/compile` - reloads all domains, routes and redirects -func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server { - r := httprouter.New() - - // Endpoint for compile action - r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { - if !hasPerms(conf.Signer, req, "violet:compile") { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - - // Trigger the compile action - compileTarget.Compile() - rw.WriteHeader(http.StatusAccepted) - }) - - // Endpoint for domains - r.PUT("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if !hasPerms(conf.Signer, req, "violet:domains") { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - - // add domain with active state - q := req.URL.Query() - conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1") - conf.Domains.Compile() - }) - r.DELETE("/domain/:domain", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if !hasPerms(conf.Signer, req, "violet:domains") { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - - // add domain with active state - q := req.URL.Query() - conf.Domains.Put(params.ByName("domain"), q.Get("active") == "1") - conf.Domains.Compile() - }) - - // Endpoint for routes - r.POST("/route", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { - - }) - - // Endpoint for acme-challenge - r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if !hasPerms(conf.Signer, req, "violet:acme-challenge") { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - domain := params.ByName("domain") - if !conf.Domains.IsValid(domain) { - utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain") - return - } - conf.Acme.Put(domain, params.ByName("key"), params.ByName("value")) - rw.WriteHeader(http.StatusAccepted) - }) - r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if !hasPerms(conf.Signer, req, "violet:acme-challenge") { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - domain := params.ByName("domain") - if !conf.Domains.IsValid(domain) { - utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain") - return - } - conf.Acme.Delete(domain, params.ByName("key")) - rw.WriteHeader(http.StatusAccepted) - }) - - // Create and run http server - return &http.Server{ - Addr: conf.ApiListen, - Handler: r, - ReadTimeout: time.Minute, - ReadHeaderTimeout: time.Minute, - WriteTimeout: time.Minute, - IdleTimeout: time.Minute, - MaxHeaderBytes: 2500, - } -} - -func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool { - // Get bearer token - bearer := utils.GetBearer(req) - if bearer == "" { - return false - } - - // Read claims from mjwt - _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer) - if err != nil { - return false - } - - // Token must have perm - return b.Claims.Perms.Has(perm) -} diff --git a/servers/api/api.go b/servers/api/api.go new file mode 100644 index 0000000..7aa5e20 --- /dev/null +++ b/servers/api/api.go @@ -0,0 +1,102 @@ +package api + +import ( + "encoding/json" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/claims" + "github.com/MrMelon54/violet/servers/conf" + "github.com/MrMelon54/violet/utils" + "github.com/julienschmidt/httprouter" + "net/http" + "time" +) + +// NewApiServer creates and runs a http server containing all the API +// endpoints for the software +// +// `/compile` - reloads all domains, routes and redirects +func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Server { + r := httprouter.New() + + // Endpoint for compile action + r.POST("/compile", checkAuthWithPerm(conf.Signer, "violet:compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, b AuthClaims) { + // Trigger the compile action + compileTarget.Compile() + rw.WriteHeader(http.StatusAccepted) + })) + + // Endpoint for domains + domainFunc := domainManage(conf.Signer, conf.Domains) + r.PUT("/domain/:domain", domainFunc) + r.DELETE("/domain/:domain", domainFunc) + + // Endpoint code for target routes/redirects + targetApis := SetupTargetApis(conf.Signer, conf.Router) + + // Endpoint for routes + r.POST("/route", targetApis.CreateRoute) + r.DELETE("/route", targetApis.DeleteRoute) + + // Endpoint for redirects + r.POST("/redirect", targetApis.CreateRedirect) + r.DELETE("/redirect", targetApis.DeleteRedirect) + + // Endpoint for acme-challenge + acmeChallengeFunc := acmeChallengeManage(conf.Signer, conf.Domains, conf.Acme) + r.PUT("/acme-challenge/:domain/:key/:value", acmeChallengeFunc) + r.DELETE("/acme-challenge/:domain/:key", acmeChallengeFunc) + + // Create and run http server + return &http.Server{ + Addr: conf.ApiListen, + Handler: r, + ReadTimeout: time.Minute, + ReadHeaderTimeout: time.Minute, + WriteTimeout: time.Minute, + IdleTimeout: time.Minute, + MaxHeaderBytes: 2500, + } +} + +// apiError outputs a generic JSON error message +func apiError(rw http.ResponseWriter, code int, m string) { + rw.WriteHeader(code) + _ = json.NewEncoder(rw).Encode(map[string]string{ + "error": m, + }) +} + +func domainManage(verify mjwt.Verifier, domains utils.DomainProvider) httprouter.Handle { + return checkAuthWithPerm(verify, "violet:domains", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + // add domain with active state + domains.Put(params.ByName("domain"), req.Method == http.MethodPut) + domains.Compile() + }) +} + +func acmeChallengeManage(verify mjwt.Verifier, domains utils.DomainProvider, acme utils.AcmeChallengeProvider) httprouter.Handle { + return checkAuthWithPerm(verify, "violet:acme-challenge", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + domain := params.ByName("domain") + if !domains.IsValid(domain) { + utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain") + return + } + if req.Method == http.MethodPut { + acme.Put(domain, params.ByName("key"), params.ByName("value")) + } else { + acme.Delete(domain, params.ByName("key")) + } + rw.WriteHeader(http.StatusAccepted) + }) +} + +// validateDomainOwnershipClaims validates if the claims contain the +// `owns=` field with the matching top level domain +func validateDomainOwnershipClaims(a string, perms *claims.PermStorage) bool { + if fqdn, ok := utils.GetTopFqdn(a); ok { + if perms.Has("owns=" + fqdn) { + return true + } + } + return false +} diff --git a/servers/api_test.go b/servers/api/api_test.go similarity index 68% rename from servers/api_test.go rename to servers/api/api_test.go index a3c9555..10c4fcc 100644 --- a/servers/api_test.go +++ b/servers/api/api_test.go @@ -1,59 +1,22 @@ -package servers +package api import ( - "crypto/rand" - "crypto/rsa" - "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/auth" - "github.com/MrMelon54/mjwt/claims" + "github.com/MrMelon54/violet/servers/conf" "github.com/MrMelon54/violet/utils" + "github.com/MrMelon54/violet/utils/fake" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "testing" - "time" ) -var snakeOilProv = genSnakeOilProv() - -type fakeDomains struct{} - -func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" } -func (f *fakeDomains) Put(string, bool) {} -func (f *fakeDomains) Delete(string) {} -func (f *fakeDomains) Compile() {} - -func genSnakeOilProv() mjwt.Signer { - key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - panic(err) - } - return mjwt.NewMJwtSigner("violet.test", key) -} - -func genSnakeOilKey(perm string) string { - p := claims.NewPermStorage() - p.Set(perm) - val, err := snakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p}) - if err != nil { - panic(err) - } - return val -} - -type fakeCompilable struct{ done bool } - -func (f *fakeCompilable) Compile() { f.done = true } - -var _ utils.Compilable = &fakeCompilable{} - func TestNewApiServer_Compile(t *testing.T) { - apiConf := &Conf{ - Domains: &fakeDomains{}, + apiConf := &conf.Conf{ + Domains: &fake.Domains{}, Acme: utils.NewAcmeChallenge(), - Signer: snakeOilProv, + Signer: fake.SnakeOilProv, } - f := &fakeCompilable{} + f := &fake.Compilable{} srv := NewApiServer(apiConf, utils.MultiCompilable{f}) req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil) @@ -63,25 +26,25 @@ func TestNewApiServer_Compile(t *testing.T) { srv.Handler.ServeHTTP(rec, req) res := rec.Result() assert.Equal(t, http.StatusForbidden, res.StatusCode) - assert.False(t, f.done) + assert.False(t, f.Done) - req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile")) + req.Header.Set("Authorization", "Bearer "+fake.GenSnakeOilKey("violet:compile")) rec = httptest.NewRecorder() srv.Handler.ServeHTTP(rec, req) res = rec.Result() assert.Equal(t, http.StatusAccepted, res.StatusCode) - assert.True(t, f.done) + assert.True(t, f.Done) } func TestNewApiServer_AcmeChallenge_Put(t *testing.T) { - apiConf := &Conf{ - Domains: &fakeDomains{}, + apiConf := &conf.Conf{ + Domains: &fake.Domains{}, Acme: utils.NewAcmeChallenge(), - Signer: snakeOilProv, + Signer: fake.SnakeOilProv, } srv := NewApiServer(apiConf, utils.MultiCompilable{}) - acmeKey := genSnakeOilKey("violet:acme-challenge") + acmeKey := fake.GenSnakeOilKey("violet:acme-challenge") // Valid domain req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil) @@ -119,13 +82,13 @@ func TestNewApiServer_AcmeChallenge_Put(t *testing.T) { } func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) { - apiConf := &Conf{ - Domains: &fakeDomains{}, + apiConf := &conf.Conf{ + Domains: &fake.Domains{}, Acme: utils.NewAcmeChallenge(), - Signer: snakeOilProv, + Signer: fake.SnakeOilProv, } srv := NewApiServer(apiConf, utils.MultiCompilable{}) - acmeKey := genSnakeOilKey("violet:acme-challenge") + acmeKey := fake.GenSnakeOilKey("violet:acme-challenge") // Valid domain req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil) diff --git a/servers/api/auth.go b/servers/api/auth.go new file mode 100644 index 0000000..4696e9b --- /dev/null +++ b/servers/api/auth.go @@ -0,0 +1,49 @@ +package api + +import ( + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/auth" + "github.com/MrMelon54/violet/utils" + "github.com/julienschmidt/httprouter" + "net/http" +) + +type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims] + +type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) + +// checkAuth validates the bearer token against a mjwt.Verifier and returns an +// error message or continues to the next handler +func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle { + return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + // Get bearer token + bearer := utils.GetBearer(req) + if bearer == "" { + apiError(rw, http.StatusForbidden, "Missing bearer token") + return + } + + // Read claims from mjwt + _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer) + if err != nil { + apiError(rw, http.StatusForbidden, "Invalid token") + return + } + + cb(rw, req, params, AuthClaims(b)) + } +} + +// checkAuthWithPerm validates the bearer token and checks if it contains a +// required permission and returns an error message or continues to the next +// handler +func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle { + return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + // check perms + if !b.Claims.Perms.Has(perm) { + apiError(rw, http.StatusForbidden, "No permission") + return + } + cb(rw, req, params, b) + }) +} diff --git a/servers/api/target-types.go b/servers/api/target-types.go new file mode 100644 index 0000000..04a30e0 --- /dev/null +++ b/servers/api/target-types.go @@ -0,0 +1,27 @@ +package api + +import ( + "github.com/MrMelon54/violet/target" +) + +type sourceJson struct { + Src string `json:"src"` +} + +func (s sourceJson) GetSource() string { return s.Src } + +type routeSource target.Route + +func (r routeSource) GetSource() string { return r.Src } + +type redirectSource target.Redirect + +func (r redirectSource) GetSource() string { return r.Src } + +var ( + _ sourceGetter = sourceJson{} + _ sourceGetter = routeSource{} + _ sourceGetter = redirectSource{} +) + +type sourceGetter interface{ GetSource() string } diff --git a/servers/api/target.go b/servers/api/target.go new file mode 100644 index 0000000..f14433f --- /dev/null +++ b/servers/api/target.go @@ -0,0 +1,88 @@ +package api + +import ( + "encoding/json" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/violet/router" + "github.com/MrMelon54/violet/target" + "github.com/MrMelon54/violet/utils" + "github.com/julienschmidt/httprouter" + "log" + "net/http" + "strings" +) + +type TargetApis struct { + CreateRoute httprouter.Handle + DeleteRoute httprouter.Handle + CreateRedirect httprouter.Handle + DeleteRedirect httprouter.Handle +} + +func SetupTargetApis(verify mjwt.Verifier, manager *router.Manager) *TargetApis { + r := &TargetApis{ + CreateRoute: parseJsonAndCheckOwnership[routeSource](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t routeSource) { + err := manager.InsertRoute(target.Route(t)) + if err != nil { + log.Printf("[Violet] Failed to insert route into database: %s\n", err) + apiError(rw, http.StatusInternalServerError, "Failed to insert route into database") + return + } + manager.Compile() + }), + DeleteRoute: parseJsonAndCheckOwnership[sourceJson](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) { + err := manager.DeleteRoute(t.Src) + if err != nil { + log.Printf("[Violet] Failed to delete route from database: %s\n", err) + apiError(rw, http.StatusInternalServerError, "Failed to delete route from database") + return + } + manager.Compile() + }), + CreateRedirect: parseJsonAndCheckOwnership[redirectSource](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t redirectSource) { + err := manager.InsertRedirect(target.Redirect(t)) + if err != nil { + log.Printf("[Violet] Failed to insert redirect into database: %s\n", err) + apiError(rw, http.StatusInternalServerError, "Failed to insert redirect into database") + return + } + manager.Compile() + }), + DeleteRedirect: parseJsonAndCheckOwnership[sourceJson](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) { + err := manager.DeleteRedirect(t.Src) + if err != nil { + log.Printf("[Violet] Failed to delete redirect from database: %s\n", err) + apiError(rw, http.StatusInternalServerError, "Failed to delete redirect from database") + return + } + manager.Compile() + }), + } + return r +} + +type AuthWithJsonCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t T) + +func parseJsonAndCheckOwnership[T sourceGetter](verify mjwt.Verifier, t string, cb AuthWithJsonCallback[T]) httprouter.Handle { + return checkAuthWithPerm(verify, "violet:"+t, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + var j T + if json.NewDecoder(req.Body).Decode(&j) != nil { + apiError(rw, http.StatusBadRequest, "Invalid request body") + return + } + + // check token owns this domain + host, _ := utils.SplitHostPath(j.GetSource()) + if strings.IndexByte(host, ':') != -1 { + apiError(rw, http.StatusBadRequest, "Invalid route source") + return + } + + if !validateDomainOwnershipClaims(host, b.Claims.Perms) { + apiError(rw, http.StatusBadRequest, "Token cannot modify the specified domain") + return + } + + cb(rw, req, params, b, j) + }) +} diff --git a/servers/conf.go b/servers/conf/conf.go similarity index 56% rename from servers/conf.go rename to servers/conf/conf.go index eb39333..69941fb 100644 --- a/servers/conf.go +++ b/servers/conf/conf.go @@ -1,12 +1,12 @@ -package servers +package conf import ( - "crypto/tls" "database/sql" "github.com/MrMelon54/mjwt" errorPages "github.com/MrMelon54/violet/error-pages" "github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/router" + "github.com/MrMelon54/violet/utils" ) // Conf stores the shared configuration for the API, HTTP and HTTPS servers. @@ -16,29 +16,11 @@ type Conf struct { HttpsListen string // https server listen address RateLimit uint64 // rate limit per minute DB *sql.DB - Domains DomainProvider - Acme AcmeChallengeProvider - Certs CertProvider + Domains utils.DomainProvider + Acme utils.AcmeChallengeProvider + Certs utils.CertProvider Favicons *favicons.Favicons Signer mjwt.Verifier ErrorPages *errorPages.ErrorPages Router *router.Manager } - -type DomainProvider interface { - IsValid(host string) bool - Put(domain string, active bool) - Delete(domain string) - Compile() -} - -type AcmeChallengeProvider interface { - Get(domain, key string) string - Put(domain, key, value string) - Delete(domain, key string) -} - -type CertProvider interface { - GetCertForDomain(domain string) *tls.Certificate - Compile() -} diff --git a/servers/http.go b/servers/http.go index fe1fb58..a7166b3 100644 --- a/servers/http.go +++ b/servers/http.go @@ -2,6 +2,7 @@ package servers import ( "fmt" + "github.com/MrMelon54/violet/servers/conf" "github.com/MrMelon54/violet/utils" "github.com/julienschmidt/httprouter" "net/http" @@ -14,7 +15,7 @@ import ( // // `/.well-known/acme-challenge/{token}` is used for outputting answers for // acme challenges, this is used for Let's Encrypt HTTP verification. -func NewHttpServer(conf *Conf) *http.Server { +func NewHttpServer(conf *conf.Conf) *http.Server { r := httprouter.New() var secureExtend string _, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443) diff --git a/servers/http_test.go b/servers/http_test.go index eaa1fe0..04428c4 100644 --- a/servers/http_test.go +++ b/servers/http_test.go @@ -2,7 +2,9 @@ package servers import ( "bytes" + "github.com/MrMelon54/violet/servers/conf" "github.com/MrMelon54/violet/utils" + "github.com/MrMelon54/violet/utils/fake" "github.com/stretchr/testify/assert" "io" "net/http" @@ -11,10 +13,10 @@ import ( ) func TestNewHttpServer_AcmeChallenge(t *testing.T) { - httpConf := &Conf{ - Domains: &fakeDomains{}, + httpConf := &conf.Conf{ + Domains: &fake.Domains{}, Acme: utils.NewAcmeChallenge(), - Signer: snakeOilProv, + Signer: fake.SnakeOilProv, } srv := NewHttpServer(httpConf) httpConf.Acme.Put("example.com", "456", "456def") diff --git a/servers/https.go b/servers/https.go index 132d1a6..bf29a72 100644 --- a/servers/https.go +++ b/servers/https.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "fmt" "github.com/MrMelon54/violet/favicons" + "github.com/MrMelon54/violet/servers/conf" "github.com/MrMelon54/violet/utils" "github.com/sethvargo/go-limiter/httplimit" "github.com/sethvargo/go-limiter/memorystore" @@ -16,7 +17,7 @@ import ( // NewHttpsServer creates and runs a http server containing the public https // endpoints for the reverse proxy. -func NewHttpsServer(conf *Conf) *http.Server { +func NewHttpsServer(conf *conf.Conf) *http.Server { return &http.Server{ Addr: conf.HttpsListen, Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)), diff --git a/servers/https_test.go b/servers/https_test.go index a6f1d68..69d85c9 100644 --- a/servers/https_test.go +++ b/servers/https_test.go @@ -5,6 +5,8 @@ import ( "github.com/MrMelon54/violet/certs" "github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/router" + "github.com/MrMelon54/violet/servers/conf" + "github.com/MrMelon54/violet/utils/fake" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "net/http" @@ -26,11 +28,11 @@ func TestNewHttpsServer_RateLimit(t *testing.T) { assert.NoError(t, err) ft := &fakeTransport{} - httpsConf := &Conf{ + httpsConf := &conf.Conf{ RateLimit: 5, - Domains: &fakeDomains{}, + Domains: &fake.Domains{}, Certs: certs.New(nil, nil, true), - Signer: snakeOilProv, + Signer: fake.SnakeOilProv, Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)), } srv := NewHttpsServer(httpsConf) diff --git a/target/flags.go b/target/flags.go new file mode 100644 index 0000000..eccad3c --- /dev/null +++ b/target/flags.go @@ -0,0 +1,41 @@ +package target + +type Flags uint64 + +const ( + FlagPre Flags = 1 << iota + FlagAbs + FlagCors + FlagSecureMode + FlagForwardHost + FlagForwardAddr + FlagIgnoreCert +) + +var ( + routeFlagMask = FlagPre | FlagAbs | FlagCors | FlagSecureMode | FlagForwardHost | FlagForwardAddr | FlagIgnoreCert + redirectFlagMask = FlagPre | FlagAbs +) + +// HasFlag returns true if the bits contain the requested flag +func (f Flags) HasFlag(flag Flags) bool { + // 0110 & 0100 == 0100 (value != 0 thus true) + // 0011 & 0100 == 0000 (value == 0 thus false) + return f&flag != 0 +} + +// NormaliseRouteFlags returns only the bits used for routes +func (f Flags) NormaliseRouteFlags() Flags { + // removes bits outside the mask + // 0110 & 0111 == 0110 + // 1010 & 0111 == 0010 (values are different) + return f & routeFlagMask +} + +// NormaliseRedirectFlags returns only the bits used for redirects +func (f Flags) NormaliseRedirectFlags() Flags { + // removes bits outside the mask + // 0110 & 0111 == 0110 + // 1010 & 0111 == 0010 (values are different) + return f & redirectFlagMask +} diff --git a/target/redirect.go b/target/redirect.go index ac59246..599d696 100644 --- a/target/redirect.go +++ b/target/redirect.go @@ -12,20 +12,14 @@ import ( // Redirect is a target used by the router to manage redirecting the request // using the specified configuration. type Redirect struct { - Pre bool // if the path has had a prefix removed - Host string // target host - Port int // target port - Path string // target path (possibly a prefix or absolute) - Abs bool // if the path is a prefix or absolute - Code int // status code used to redirect + Src string `json:"src"` // request source + Dst string `json:"dst"` // redirect destination + Flags Flags `json:"flags"` // extra flags + Code int `json:"code"` // status code used to redirect } -// FullHost outputs a host:port combo or just the host if the port is 0. -func (r Redirect) FullHost() string { - if r.Port == 0 { - return r.Host - } - return fmt.Sprintf("%s:%d", r.Host, r.Port) +func (r Route) HasFlag(flag Flags) bool { + return r.Flags&flag != 0 } // ServeHTTP responds with the redirect to the response writer provided. @@ -36,10 +30,12 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { code = http.StatusFound } + // split the host and path + host, p := utils.SplitHostPath(r.Dst) + // if not Abs then join with the ending of the current path - p := r.Path - if !r.Abs { - p = path.Join(r.Path, req.URL.Path) + if !r.Flags.HasFlag(FlagAbs) { + p = path.Join(p, req.URL.Path) // replace the trailing slash that path.Join() strips off if strings.HasSuffix(req.URL.Path, "/") { @@ -55,7 +51,7 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // create a new URL u := &url.URL{ Scheme: req.URL.Scheme, - Host: r.FullHost(), + Host: host, Path: p, } diff --git a/target/redirect_test.go b/target/redirect_test.go index f686526..eacdb78 100644 --- a/target/redirect_test.go +++ b/target/redirect_test.go @@ -7,18 +7,13 @@ import ( "testing" ) -func TestRedirect_FullHost(t *testing.T) { - assert.Equal(t, "localhost", Redirect{Host: "localhost"}.FullHost()) - assert.Equal(t, "localhost:22", Redirect{Host: "localhost", Port: 22}.FullHost()) -} - func TestRedirect_ServeHTTP(t *testing.T) { a := []struct { Redirect target string }{ - {Redirect{Host: "example.com", Path: "/bye", Abs: true, Code: http.StatusFound}, "https://example.com/bye"}, - {Redirect{Host: "example.com", Path: "/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"}, + {Redirect{Dst: "example.com/bye", Flags: FlagAbs, Code: http.StatusFound}, "https://example.com/bye"}, + {Redirect{Dst: "example.com/bye", Code: http.StatusFound}, "https://example.com/bye/hello/world"}, } for _, i := range a { res := httptest.NewRecorder() diff --git a/target/route.go b/target/route.go index 3d3ba3a..ff3bae9 100644 --- a/target/route.go +++ b/target/route.go @@ -36,18 +36,11 @@ var serveApiCors = cors.New(cors.Options{ // Route is a target used by the router to manage forwarding traffic to an // internal server using the specified configuration. type Route struct { - Pre bool // if the path has had a prefix removed - Host string // target host - Port int // target port - Path string // target path (possibly a prefix or absolute) - Abs bool // if the path is a prefix or absolute - Cors bool // add CORS headers - SecureMode bool // use HTTPS internally - ForwardHost bool // forward host header internally - ForwardAddr bool // forward remote address - IgnoreCert bool // ignore self-cert - Headers http.Header // extra headers - Proxy *proxy.HybridTransport // reverse proxy handler + Src string `json:"src"` // request source + Dst string `json:"dst"` // proxy destination + Flags Flags `json:"flags"` // extra flags + Headers http.Header `json:"-"` // extra headers + Proxy *proxy.HybridTransport `json:"-"` // reverse proxy handler } // UpdateHeaders takes an existing set of headers and overwrites them with the @@ -58,18 +51,10 @@ func (r Route) UpdateHeaders(header http.Header) { } } -// FullHost outputs a host:port combo or just the host if the port is 0. -func (r Route) FullHost() string { - if r.Port == 0 { - return r.Host - } - return fmt.Sprintf("%s:%d", r.Host, r.Port) -} - // ServeHTTP responds with the data proxied from the internal server to the // response writer provided. func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if r.Cors { + if r.HasFlag(FlagCors) { // wraps with CORS handler serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req) } else { @@ -82,21 +67,16 @@ func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { // set the scheme and port using defaults if the port is 0 scheme := "http" - if r.SecureMode { + if r.HasFlag(FlagSecureMode) { scheme = "https" - if r.Port == 0 { - r.Port = 443 - } - } else { - if r.Port == 0 { - r.Port = 80 - } } + // split the host and path + host, p := utils.SplitHostPath(r.Dst) + // if not Abs then join with the ending of the current path - p := r.Path - if !r.Abs { - p = path.Join(r.Path, req.URL.Path) + if !r.HasFlag(FlagAbs) { + p = path.Join(p, req.URL.Path) // replace the trailing slash that path.Join() strips off if strings.HasSuffix(req.URL.Path, "/") { @@ -112,7 +92,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { // create a new URL u := &url.URL{ Scheme: scheme, - Host: r.FullHost(), + Host: host, Path: p, RawQuery: req.URL.RawQuery, } @@ -150,10 +130,10 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { } // if forward host is enabled then send the host - if r.ForwardHost { + if r.HasFlag(FlagForwardHost) { req2.Host = req.Host } - if r.ForwardAddr { + if r.HasFlag(FlagForwardAddr) { req2.Header.Add("X-Forwarded-For", req.RemoteAddr) } @@ -162,7 +142,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { // serve request with reverse proxy var resp *http.Response - if r.IgnoreCert { + if r.HasFlag(FlagIgnoreCert) { resp, err = r.Proxy.InsecureRoundTrip(req2) } else { resp, err = r.Proxy.SecureRoundTrip(req2) diff --git a/target/route_test.go b/target/route_test.go index 2b08339..3177318 100644 --- a/target/route_test.go +++ b/target/route_test.go @@ -25,9 +25,9 @@ func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) { return &http.Response{StatusCode: http.StatusOK}, nil } -func TestRoute_FullHost(t *testing.T) { - assert.Equal(t, "localhost", Route{Host: "localhost"}.FullHost()) - assert.Equal(t, "localhost:22", Route{Host: "localhost", Port: 22}.FullHost()) +func TestRoute_HasFlag(t *testing.T) { + assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre)) + assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors)) } func TestRoute_ServeHTTP(t *testing.T) { @@ -35,12 +35,12 @@ func TestRoute_ServeHTTP(t *testing.T) { Route target string }{ - {Route{Host: "localhost", Port: 1234, Path: "/bye", Abs: true}, "http://localhost:1234/bye"}, - {Route{Host: "1.2.3.4", Path: "/bye"}, "http://1.2.3.4:80/bye/hello/world"}, - {Route{Host: "2.2.2.2", Path: "/world", Abs: true, SecureMode: true}, "https://2.2.2.2:443/world"}, - {Route{Host: "api.example.com", Path: "/world", Abs: true, SecureMode: true, ForwardHost: true}, "https://api.example.com:443/world"}, - {Route{Host: "api.example.org", Path: "/world", Abs: true, SecureMode: true, ForwardAddr: true}, "https://api.example.org:443/world"}, - {Route{Host: "3.3.3.3", Path: "/headers", Abs: true, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3:80/headers"}, + {Route{Dst: "localhost:1234/bye", Flags: FlagAbs}, "http://localhost:1234/bye"}, + {Route{Dst: "1.2.3.4/bye"}, "http://1.2.3.4/bye/hello/world"}, + {Route{Dst: "2.2.2.2/world", Flags: FlagAbs | FlagSecureMode}, "https://2.2.2.2/world"}, + {Route{Dst: "api.example.com/world", Flags: FlagAbs | FlagSecureMode | FlagForwardHost}, "https://api.example.com/world"}, + {Route{Dst: "api.example.org/world", Flags: FlagAbs | FlagSecureMode | FlagForwardAddr}, "https://api.example.org/world"}, + {Route{Dst: "3.3.3.3/headers", Flags: FlagAbs, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3/headers"}, } for _, i := range a { pt := &proxyTester{} @@ -51,10 +51,10 @@ func TestRoute_ServeHTTP(t *testing.T) { assert.True(t, pt.got) assert.Equal(t, i.target, pt.req.URL.String()) - if i.ForwardAddr { + if i.HasFlag(FlagForwardAddr) { assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For")) } - if i.ForwardHost { + if i.HasFlag(FlagForwardHost) { assert.Equal(t, req.Host, pt.req.Host) } if i.Headers != nil { @@ -68,7 +68,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) { res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "https://www.example.com/test", nil) req.Header.Set("Origin", "https://test.example.com") - i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()} + i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()} i.ServeHTTP(res, req) assert.True(t, pt.got) @@ -86,7 +86,7 @@ func TestRoute_ServeHTTP_Body(t *testing.T) { buf := bytes.NewBuffer([]byte{0x54}) req := httptest.NewRequest(http.MethodPost, "https://www.example.com/test", buf) req.Header.Set("Origin", "https://test.example.com") - i := &Route{Host: "1.1.1.1", Port: 8080, Path: "/hello", Cors: true, Proxy: pt.makeHybridTransport()} + i := &Route{Dst: "1.1.1.1:8080/hello", Flags: FlagCors, Proxy: pt.makeHybridTransport()} i.ServeHTTP(res, req) assert.True(t, pt.got) diff --git a/utils/domain-utils.go b/utils/domain-utils.go index acedb6f..762a1ee 100644 --- a/utils/domain-utils.go +++ b/utils/domain-utils.go @@ -83,3 +83,38 @@ func GetTopFqdn(domain string) (string, bool) { } return domain[n+1:], true } + +// SplitHostPath extracts the host/path from the input +func SplitHostPath(a string) (host, path string) { + // check if source has path + n := strings.IndexByte(a, '/') + if n == -1 { + // set host then path to / + host = a + path = "/" + } else { + // set host then custom path + host = a[:n] + path = a[n:] // this required to keep / at the start of the path + } + return +} + +// SplitHostPathQuery extracts the host/path?query from the input +func SplitHostPathQuery(a string) (host, path, query string) { + host, path = SplitHostPath(a) + if path == "/" { + n := strings.IndexByte(host, '?') + if n != -1 { + query = host[n+1:] + host = host[:n] + } + return + } + n := strings.IndexByte(path, '?') + if n != -1 { + query = path[n+1:] + path = path[:n] // reassign happens after + } + return +} diff --git a/utils/domain-utils_test.go b/utils/domain-utils_test.go index 63f51e6..a058d7e 100644 --- a/utils/domain-utils_test.go +++ b/utils/domain-utils_test.go @@ -60,3 +60,40 @@ func TestGetTopFqdn(t *testing.T) { assert.True(t, ok, "Output should be true") assert.Equal(t, "example.com", domain) } + +func TestSplitHostPath(t *testing.T) { + h, p := SplitHostPath("example.com/hello/world") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/hello/world", p) + + h, p = SplitHostPath("example.com") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/", p) +} + +func TestSplitHostPathQuery(t *testing.T) { + h, p, q := SplitHostPathQuery("example.com/hello/world") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/hello/world", p) + assert.Equal(t, "", q) + + h, p, q = SplitHostPathQuery("example.com") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/", p) + assert.Equal(t, "", q) + + h, p, q = SplitHostPathQuery("example.com/hello/world?a=b") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/hello/world", p) + assert.Equal(t, "a=b", q) + + h, p, q = SplitHostPathQuery("example.com?a=b") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/", p) + assert.Equal(t, "a=b", q) + + h, p, q = SplitHostPathQuery("example.com/?a=b") + assert.Equal(t, "example.com", h) + assert.Equal(t, "/", p) + assert.Equal(t, "a=b", q) +} diff --git a/utils/fake/fake-compilable.go b/utils/fake/fake-compilable.go new file mode 100644 index 0000000..8dfd9f8 --- /dev/null +++ b/utils/fake/fake-compilable.go @@ -0,0 +1,11 @@ +package fake + +import "github.com/MrMelon54/violet/utils" + +// Compilable implements utils.Compilable and stores if the Compile function +// is called. +type Compilable struct{ Done bool } + +func (f *Compilable) Compile() { f.Done = true } + +var _ utils.Compilable = &Compilable{} diff --git a/utils/fake/fake-domains.go b/utils/fake/fake-domains.go new file mode 100644 index 0000000..5dee5ae --- /dev/null +++ b/utils/fake/fake-domains.go @@ -0,0 +1,13 @@ +package fake + +import "github.com/MrMelon54/violet/utils" + +// Domains implements DomainProvider and makes sure `example.com` is valid +type Domains struct{} + +func (f *Domains) IsValid(host string) bool { return host == "example.com" } +func (f *Domains) Put(string, bool) {} +func (f *Domains) Delete(string) {} +func (f *Domains) Compile() {} + +var _ utils.DomainProvider = &Domains{} diff --git a/utils/fake/fake.go b/utils/fake/fake.go new file mode 100644 index 0000000..ff7d53e --- /dev/null +++ b/utils/fake/fake.go @@ -0,0 +1,2 @@ +// Package fake contains fake structs used during tests +package fake diff --git a/utils/fake/mjwt.go b/utils/fake/mjwt.go new file mode 100644 index 0000000..03a5df4 --- /dev/null +++ b/utils/fake/mjwt.go @@ -0,0 +1,30 @@ +package fake + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/auth" + "github.com/MrMelon54/mjwt/claims" + "time" +) + +var SnakeOilProv = GenSnakeOilProv() + +func GenSnakeOilProv() mjwt.Signer { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + return mjwt.NewMJwtSigner("violet.test", key) +} + +func GenSnakeOilKey(perm string) string { + p := claims.NewPermStorage() + p.Set(perm) + val, err := SnakeOilProv.GenerateJwt("abc", "abc", nil, 5*time.Minute, auth.AccessTokenClaims{Perms: p}) + if err != nil { + panic(err) + } + return val +} diff --git a/utils/interfaces.go b/utils/interfaces.go new file mode 100644 index 0000000..749879d --- /dev/null +++ b/utils/interfaces.go @@ -0,0 +1,21 @@ +package utils + +import "crypto/tls" + +type DomainProvider interface { + IsValid(host string) bool + Put(domain string, active bool) + Delete(domain string) + Compile() +} + +type AcmeChallengeProvider interface { + Get(domain, key string) string + Put(domain, key, value string) + Delete(domain, key string) +} + +type CertProvider interface { + GetCertForDomain(domain string) *tls.Certificate + Compile() +}