From 0e42e54f08ef042bb75aa9efa6b2f43a3b58e3d1 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Fri, 21 Apr 2023 03:21:46 +0100 Subject: [PATCH] Start writing main function --- .idea/sqldialects.xml | 7 +++ benchmarks/router_test.go | 2 +- cmd/violet/init.sql | 1 + cmd/violet/main.go | 56 +++++++++++++++++- domains/domains.go | 36 ++++++++++++ go.mod | 7 +++ go.sum | 14 +++++ router/router.go | 6 +- router/router_test.go | 2 +- servers/api.go | 59 +++++++++++++++++++ servers/http.go | 26 +++++++++ servers/https.go | 1 + target/route.go | 114 +++++++++++++++++++++++++++++++++++-- utils/domain-utils.go | 61 ++++++++++++++++++++ utils/domain-utils_test.go | 58 +++++++++++++++++++ utils/multi-compilable.go | 13 +++++ utils/response.go | 10 ++++ utils/server-utils.go | 33 +++++++++++ 18 files changed, 495 insertions(+), 11 deletions(-) create mode 100644 .idea/sqldialects.xml create mode 100644 cmd/violet/init.sql create mode 100644 domains/domains.go create mode 100644 servers/api.go create mode 100644 servers/http.go create mode 100644 servers/https.go create mode 100644 utils/domain-utils.go create mode 100644 utils/domain-utils_test.go create mode 100644 utils/multi-compilable.go create mode 100644 utils/response.go create mode 100644 utils/server-utils.go diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml new file mode 100644 index 0000000..ec81629 --- /dev/null +++ b/.idea/sqldialects.xml @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/benchmarks/router_test.go b/benchmarks/router_test.go index b8d089f..61d68c8 100644 --- a/benchmarks/router_test.go +++ b/benchmarks/router_test.go @@ -22,7 +22,7 @@ func benchRequest(b *testing.B, router http.Handler, r *http.Request) { } func BenchmarkVioletRouter(b *testing.B) { - r := router.New() + r := router.New(nil) r.AddRedirect("*.example.com", "", target.Redirect{ Pre: true, Host: "example.com", diff --git a/cmd/violet/init.sql b/cmd/violet/init.sql new file mode 100644 index 0000000..4cc4215 --- /dev/null +++ b/cmd/violet/init.sql @@ -0,0 +1 @@ +create table acme_challenge (id integer not null primary key, key varchar, value varchar); diff --git a/cmd/violet/main.go b/cmd/violet/main.go index ace8a1e..264545f 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -1,5 +1,59 @@ -package violet +package main + +import ( + "database/sql" + _ "embed" + "errors" + "flag" + "github.com/MrMelon54/violet/domains" + "github.com/MrMelon54/violet/proxy" + "github.com/MrMelon54/violet/router" + "github.com/MrMelon54/violet/servers" + "github.com/MrMelon54/violet/utils" + _ "github.com/mattn/go-sqlite3" + "log" + "os" +) + +//go:embed init.sql +var initSql string + +var ( + databasePath = flag.String("db", "", "/path/to/database.sqlite") + certPath = flag.String("cert", "", "/path/to/certificates") + apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") + httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") + httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") +) func main() { + log.Println("[Violet] Starting...") + _, err := os.Stat(*certPath) + if errors.Is(err, os.ErrNotExist) { + log.Fatalf("[Violet] Certificate path '%s' does not exists", *certPath) + } + + _, err = os.Stat(*databasePath) + dbExists := !errors.Is(err, os.ErrNotExist) + + db, err := sql.Open("sqlite3", *databasePath) + if err != nil { + log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath) + } + + if !dbExists { + log.Println("[Violet] Creating new database and running init.sql") + _, err = db.Exec(initSql) + if err != nil { + log.Fatalf("[Violet] Failed to run init.sql") + } + } + + allowedDomains := domains.New() + reverseProxy := proxy.CreateHybridReverseProxy() + r := router.New(reverseProxy) + + servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{}) + servers.NewHttpServer(*httpListen, 0, allowedDomains) } diff --git a/domains/domains.go b/domains/domains.go new file mode 100644 index 0000000..5029276 --- /dev/null +++ b/domains/domains.go @@ -0,0 +1,36 @@ +package domains + +import ( + "github.com/MrMelon54/violet/utils" + "strings" + "sync" +) + +type Domains struct { + s *sync.RWMutex + m map[string]struct{} +} + +func New() *Domains { + return &Domains{ + s: &sync.RWMutex{}, + m: make(map[string]struct{}), + } +} + +func (d *Domains) IsValid(host string) bool { + domain, ok := utils.GetDomainWithoutPort(host) + if !ok { + return false + } + d.s.RLock() + defer d.s.RUnlock() + + n := strings.Split(domain, ".") + for i := 0; i < len(n); i++ { + if _, ok := d.m[strings.Join(n[i:], ".")]; ok { + return true + } + } + return false +} diff --git a/go.mod b/go.mod index dad864c..8472289 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,20 @@ module github.com/MrMelon54/violet go 1.20 require ( + code.mrmelon54.com/melon/summer-utils v0.0.3 github.com/MrMelon54/trie v0.0.2 + github.com/julienschmidt/httprouter v1.3.0 + github.com/mattn/go-sqlite3 v1.14.16 + github.com/mrmelon54/mjwt v0.0.1 + github.com/rs/cors v1.9.0 github.com/stretchr/testify v1.8.2 golang.org/x/net v0.9.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c8c4a50..c9a78f3 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,24 @@ +code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q= +code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw= github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY= github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mrmelon54/mjwt v0.0.1 h1:XgyWviTmgsbMiKXjxo+Jp/QSf7FF7/omkvrUag8/P5U= +github.com/mrmelon54/mjwt v0.0.1/go.mod h1:M+kZ6t9EArEQ2/CGjfgyNhAo542ot+S7gw5uJCK11Ms= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= +github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/router/router.go b/router/router.go index 165433e..de9713e 100644 --- a/router/router.go +++ b/router/router.go @@ -16,13 +16,14 @@ type Router struct { proxy *httputil.ReverseProxy } -func New() *Router { +func New(proxy *httputil.ReverseProxy) *Router { return &Router{ route: make(map[string]*trie.Trie[target.Route]), redirect: make(map[string]*trie.Trie[target.Redirect]), notFound: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { _, _ = fmt.Fprintf(rw, "%d %s\n", http.StatusNotFound, http.StatusText(http.StatusNotFound)) }), + proxy: proxy, } } @@ -45,10 +46,11 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] { } func (r *Router) AddService(host string, t target.Route) { - r.AddRoute(host, "", t) + r.AddRoute(host, "/", t) } func (r *Router) AddRoute(host string, path string, t target.Route) { + t.Proxy = r.proxy r.hostRoute(host).PutString(path, t) } diff --git a/router/router_test.go b/router/router_test.go index c4dabf0..03a68ae 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -113,7 +113,7 @@ func assertHttpRedirect(t *testing.T, r *Router, code int, target, method, start func TestRouter_AddRedirect(t *testing.T) { for _, i := range redirectTests { - r := New() + r := New(nil) dst := i.dst dst.Host = "example.com" dst.Code = http.StatusFound diff --git a/servers/api.go b/servers/api.go new file mode 100644 index 0000000..7ef4e89 --- /dev/null +++ b/servers/api.go @@ -0,0 +1,59 @@ +package servers + +import ( + "code.mrmelon54.com/melon/summer-utils/claims/auth" + "github.com/MrMelon54/violet/utils" + "github.com/julienschmidt/httprouter" + "github.com/mrmelon54/mjwt" + "log" + "net/http" + "time" +) + +// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software +// +// `/compile` - reloads all domains, routes and redirects from the configuration files +func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server { + r := httprouter.New() + + // Endpoint `/compile` reloads all domains, routes and redirects from the configuration files + r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { + // Get bearer token + bearer := utils.GetBearer(req) + if bearer == "" { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + + // Read claims from mjwt + _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer) + if err != nil { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + + // Token must have `violet:compile` perm + if !b.Claims.Perms.Has("violet:compile") { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + + // Trigger the compile action + compileTarget.Compile() + rw.WriteHeader(http.StatusAccepted) + }) + + // Create and run http server + s := &http.Server{ + Addr: listen, + Handler: r, + ReadTimeout: time.Minute, + ReadHeaderTimeout: time.Minute, + WriteTimeout: time.Minute, + IdleTimeout: time.Minute, + MaxHeaderBytes: 2500, + } + log.Printf("[API] Starting API server on: '%s'\n", s.Addr) + go utils.RunBackgroundHttp("API", s) + return s +} diff --git a/servers/http.go b/servers/http.go new file mode 100644 index 0000000..23e522d --- /dev/null +++ b/servers/http.go @@ -0,0 +1,26 @@ +package servers + +import ( + "fmt" + "github.com/MrMelon54/violet/domains" + "github.com/MrMelon54/violet/utils" + "github.com/julienschmidt/httprouter" + "net/http" +) + +func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server { + r := httprouter.New() + r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok { + if !domainCheck.IsValid(req.Host) { + http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) + return + } + if tokenValue := params.ByName("token"); tokenValue != "" { + rw.WriteHeader(http.StatusOK) + return + } + } + rw.WriteHeader(http.StatusNotFound) + }) +} diff --git a/servers/https.go b/servers/https.go new file mode 100644 index 0000000..84c4cc0 --- /dev/null +++ b/servers/https.go @@ -0,0 +1 @@ +package servers diff --git a/target/route.go b/target/route.go index b0f7e75..eaa8043 100644 --- a/target/route.go +++ b/target/route.go @@ -1,17 +1,119 @@ package target import ( + "bytes" + "fmt" + "github.com/MrMelon54/violet/proxy" + "github.com/MrMelon54/violet/utils" + "github.com/rs/cors" + "io" + "log" "net/http" + "net/url" + "path" ) +var serveApiCors = cors.New(cors.Options{ + AllowedOrigins: []string{"*"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + AllowedMethods: []string{ + http.MethodGet, + http.MethodHead, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + http.MethodConnect, + http.MethodOptions, + http.MethodTrace, + }, + AllowCredentials: true, +}) + type Route struct { - Pre bool - Host string - Port int - Path string - Abs bool + Pre bool + Host string + Port int + Path string + Abs bool + Cors bool + SecureMode bool + ForwardHost bool + IgnoreCert bool + Headers http.Header + Proxy http.Handler +} + +func (r Route) IsIgnoreCert() bool { return r.IgnoreCert } + +func (r Route) UpdateHeaders(header http.Header) { + for k, v := range r.Headers { + header[k] = v + } +} + +func (r Route) FullHost() string { + if r.Port == 0 { + return r.Host + } + return fmt.Sprintf("%s:%d", r.Host, r.Port) } func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - // pass + if r.Cors { + serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req) + } else { + r.internalServeHTTP(rw, req) + } +} + +func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) { + scheme := "http" + if r.SecureMode { + scheme = "https" + if r.Port == 0 { + r.Port = 443 + } + } else { + if r.Port == 0 { + r.Port = 80 + } + } + + p := r.Path + if !r.Abs { + p = path.Join(r.Path, req.URL.Path) + } + + if p == "" { + p = "/" + } + + buf := new(bytes.Buffer) + if req.Body != nil { + _, _ = io.Copy(buf, req.Body) + } + + u := &url.URL{ + Scheme: scheme, + Host: r.FullHost(), + Path: p, + RawQuery: req.URL.RawQuery, + } + req2, err := http.NewRequest(req.Method, u.String(), buf) + if err != nil { + log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err) + utils.RespondHttpStatus(rw, http.StatusBadGateway) + return + } + for k, v := range req.Header { + if k == "Host" { + continue + } + req2.Header[k] = v + } + if r.ForwardHost { + req2.Host = req.Host + } + r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r)) } diff --git a/utils/domain-utils.go b/utils/domain-utils.go new file mode 100644 index 0000000..43d7d08 --- /dev/null +++ b/utils/domain-utils.go @@ -0,0 +1,61 @@ +package utils + +import ( + "fmt" + "strings" +) + +func SplitDomainPort(host string, defaultPort uint16) (domain string, port uint16, ok bool) { + a := strings.SplitN(host, ":", 2) + switch len(a) { + case 2: + domain = a[0] + _, err := fmt.Sscanf(a[1], "%d", &port) + ok = err == nil + case 1: + domain = a[0] + port = defaultPort + ok = true + } + return +} + +func GetDomainWithoutPort(domain string) (string, bool) { + a := strings.SplitN(domain, ":", 2) + if len(a) == 2 { + return a[0], true + } + if len(a) == 0 { + return "", false + } + return a[0], true +} + +func ReplaceSubdomainWithWildcard(domain string) (string, bool) { + a, b := GetBaseDomain(domain) + return "*." + a, b +} + +func GetBaseDomain(domain string) (string, bool) { + a := strings.SplitN(domain, ".", 2) + l := len(a) + if l == 2 { + return a[1], true + } + if l == 1 { + return a[0], true + } + return "", false +} + +func GetTopFqdn(domain string) (string, bool) { + a := strings.Split(domain, ".") + l := len(a) + if l >= 2 { + return strings.Join(a[l-2:], "."), true + } + if l == 1 { + return domain, true + } + return "", false +} diff --git a/utils/domain-utils_test.go b/utils/domain-utils_test.go new file mode 100644 index 0000000..a73609f --- /dev/null +++ b/utils/domain-utils_test.go @@ -0,0 +1,58 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSplitDomainPort(t *testing.T) { + domain, port, ok := SplitDomainPort("www.example.com:5612", 443) + assert.True(t, ok, "Output should be true") + assert.Equal(t, "www.example.com", domain) + assert.Equal(t, uint16(5612), port) + + domain, port, ok = SplitDomainPort("example.com", 443) + assert.True(t, ok, "Output should be true") + assert.Equal(t, "example.com", domain) + assert.Equal(t, uint16(443), port) +} + +func TestDomainWithoutPort(t *testing.T) { + domain, ok := GetDomainWithoutPort("www.example.com:5612") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "www.example.com", domain) + + domain, ok = GetDomainWithoutPort("example.com:443") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "example.com", domain) +} + +func TestReplaceSubdomainWithWildcard(t *testing.T) { + domain, ok := ReplaceSubdomainWithWildcard("www.example.com") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "*.example.com", domain) + + domain, ok = ReplaceSubdomainWithWildcard("www.example.com:5612") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "*.example.com:5612", domain) +} + +func TestGetBaseDomain(t *testing.T) { + domain, ok := GetBaseDomain("www.example.com") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "example.com", domain) + + domain, ok = GetBaseDomain("www.example.com:5612") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "example.com:5612", domain) +} + +func TestGetTopFqdn(t *testing.T) { + domain, ok := GetTopFqdn("www.example.com") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "example.com", domain) + + domain, ok = GetTopFqdn("www.www.example.com") + assert.True(t, ok, "Output should be true") + assert.Equal(t, "example.com", domain) +} diff --git a/utils/multi-compilable.go b/utils/multi-compilable.go new file mode 100644 index 0000000..ec76a8f --- /dev/null +++ b/utils/multi-compilable.go @@ -0,0 +1,13 @@ +package utils + +type Compilable interface { + Compile() +} + +type MultiCompilable []Compilable + +func (m MultiCompilable) Compile() { + for _, i := range m { + i.Compile() + } +} diff --git a/utils/response.go b/utils/response.go new file mode 100644 index 0000000..cfdabc6 --- /dev/null +++ b/utils/response.go @@ -0,0 +1,10 @@ +package utils + +import ( + "fmt" + "net/http" +) + +func RespondHttpStatus(rw http.ResponseWriter, status int) { + http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status) +} diff --git a/utils/server-utils.go b/utils/server-utils.go new file mode 100644 index 0000000..30f4ca6 --- /dev/null +++ b/utils/server-utils.go @@ -0,0 +1,33 @@ +package utils + +import ( + "log" + "net/http" + "strings" +) + +func logHttpServerError(prefix string, err error) { + if err != nil { + if err == http.ErrServerClosed { + log.Printf("[%s] The http server shutdown successfully\n", prefix) + } else { + log.Printf("[%s] Error trying to host the http server: %s\n", prefix, err.Error()) + } + } +} + +func RunBackgroundHttp(prefix string, s *http.Server) { + logHttpServerError(prefix, s.ListenAndServe()) +} + +func RunBackgroundHttps(prefix string, s *http.Server) { + logHttpServerError(prefix, s.ListenAndServeTLS("", "")) +} + +func GetBearer(req *http.Request) string { + a := req.Header.Get("Authorization") + if t, ok := strings.CutPrefix(a, "Bearer "); ok { + return t + } + return "" +}