Start writing main function

This commit is contained in:
Melon 2023-04-21 03:21:46 +01:00
parent d4aed095ec
commit 0e42e54f08
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
18 changed files with 495 additions and 11 deletions

7
.idea/sqldialects.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/cmd/violet/init.sql" dialect="GenericSQL" />
<file url="PROJECT" dialect="SQLite" />
</component>
</project>

View File

@ -22,7 +22,7 @@ func benchRequest(b *testing.B, router http.Handler, r *http.Request) {
}
func BenchmarkVioletRouter(b *testing.B) {
r := router.New()
r := router.New(nil)
r.AddRedirect("*.example.com", "", target.Redirect{
Pre: true,
Host: "example.com",

1
cmd/violet/init.sql Normal file
View File

@ -0,0 +1 @@
create table acme_challenge (id integer not null primary key, key varchar, value varchar);

View File

@ -1,5 +1,59 @@
package violet
package main
import (
"database/sql"
_ "embed"
"errors"
"flag"
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers"
"github.com/MrMelon54/violet/utils"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
)
//go:embed init.sql
var initSql string
var (
databasePath = flag.String("db", "", "/path/to/database.sqlite")
certPath = flag.String("cert", "", "/path/to/certificates")
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
)
func main() {
log.Println("[Violet] Starting...")
_, err := os.Stat(*certPath)
if errors.Is(err, os.ErrNotExist) {
log.Fatalf("[Violet] Certificate path '%s' does not exists", *certPath)
}
_, err = os.Stat(*databasePath)
dbExists := !errors.Is(err, os.ErrNotExist)
db, err := sql.Open("sqlite3", *databasePath)
if err != nil {
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
}
if !dbExists {
log.Println("[Violet] Creating new database and running init.sql")
_, err = db.Exec(initSql)
if err != nil {
log.Fatalf("[Violet] Failed to run init.sql")
}
}
allowedDomains := domains.New()
reverseProxy := proxy.CreateHybridReverseProxy()
r := router.New(reverseProxy)
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{})
servers.NewHttpServer(*httpListen, 0, allowedDomains)
}

36
domains/domains.go Normal file
View File

@ -0,0 +1,36 @@
package domains
import (
"github.com/MrMelon54/violet/utils"
"strings"
"sync"
)
type Domains struct {
s *sync.RWMutex
m map[string]struct{}
}
func New() *Domains {
return &Domains{
s: &sync.RWMutex{},
m: make(map[string]struct{}),
}
}
func (d *Domains) IsValid(host string) bool {
domain, ok := utils.GetDomainWithoutPort(host)
if !ok {
return false
}
d.s.RLock()
defer d.s.RUnlock()
n := strings.Split(domain, ".")
for i := 0; i < len(n); i++ {
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
return true
}
}
return false
}

7
go.mod
View File

@ -3,13 +3,20 @@ module github.com/MrMelon54/violet
go 1.20
require (
code.mrmelon54.com/melon/summer-utils v0.0.3
github.com/MrMelon54/trie v0.0.2
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16
github.com/mrmelon54/mjwt v0.0.1
github.com/rs/cors v1.9.0
github.com/stretchr/testify v1.8.2
golang.org/x/net v0.9.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

14
go.sum
View File

@ -1,10 +1,24 @@
code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q=
code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw=
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mrmelon54/mjwt v0.0.1 h1:XgyWviTmgsbMiKXjxo+Jp/QSf7FF7/omkvrUag8/P5U=
github.com/mrmelon54/mjwt v0.0.1/go.mod h1:M+kZ6t9EArEQ2/CGjfgyNhAo542ot+S7gw5uJCK11Ms=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

View File

@ -16,13 +16,14 @@ type Router struct {
proxy *httputil.ReverseProxy
}
func New() *Router {
func New(proxy *httputil.ReverseProxy) *Router {
return &Router{
route: make(map[string]*trie.Trie[target.Route]),
redirect: make(map[string]*trie.Trie[target.Redirect]),
notFound: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = fmt.Fprintf(rw, "%d %s\n", http.StatusNotFound, http.StatusText(http.StatusNotFound))
}),
proxy: proxy,
}
}
@ -45,10 +46,11 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] {
}
func (r *Router) AddService(host string, t target.Route) {
r.AddRoute(host, "", t)
r.AddRoute(host, "/", t)
}
func (r *Router) AddRoute(host string, path string, t target.Route) {
t.Proxy = r.proxy
r.hostRoute(host).PutString(path, t)
}

View File

@ -113,7 +113,7 @@ func assertHttpRedirect(t *testing.T, r *Router, code int, target, method, start
func TestRouter_AddRedirect(t *testing.T) {
for _, i := range redirectTests {
r := New()
r := New(nil)
dst := i.dst
dst.Host = "example.com"
dst.Code = http.StatusFound

59
servers/api.go Normal file
View File

@ -0,0 +1,59 @@
package servers
import (
"code.mrmelon54.com/melon/summer-utils/claims/auth"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"github.com/mrmelon54/mjwt"
"log"
"net/http"
"time"
)
// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects from the configuration files
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New()
// Endpoint `/compile` reloads all domains, routes and redirects from the configuration files
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// Get bearer token
bearer := utils.GetBearer(req)
if bearer == "" {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Token must have `violet:compile` perm
if !b.Claims.Perms.Has("violet:compile") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Trigger the compile action
compileTarget.Compile()
rw.WriteHeader(http.StatusAccepted)
})
// Create and run http server
s := &http.Server{
Addr: listen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
}
log.Printf("[API] Starting API server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttp("API", s)
return s
}

26
servers/http.go Normal file
View File

@ -0,0 +1,26 @@
package servers
import (
"fmt"
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
)
func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server {
r := httprouter.New()
r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok {
if !domainCheck.IsValid(req.Host) {
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
return
}
if tokenValue := params.ByName("token"); tokenValue != "" {
rw.WriteHeader(http.StatusOK)
return
}
}
rw.WriteHeader(http.StatusNotFound)
})
}

1
servers/https.go Normal file
View File

@ -0,0 +1 @@
package servers

View File

@ -1,17 +1,119 @@
package target
import (
"bytes"
"fmt"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/utils"
"github.com/rs/cors"
"io"
"log"
"net/http"
"net/url"
"path"
)
var serveApiCors = cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowedMethods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodConnect,
http.MethodOptions,
http.MethodTrace,
},
AllowCredentials: true,
})
type Route struct {
Pre bool
Host string
Port int
Path string
Abs bool
Cors bool
SecureMode bool
ForwardHost bool
IgnoreCert bool
Headers http.Header
Proxy http.Handler
}
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
func (r Route) UpdateHeaders(header http.Header) {
for k, v := range r.Headers {
header[k] = v
}
}
func (r Route) FullHost() string {
if r.Port == 0 {
return r.Host
}
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// pass
if r.Cors {
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
} else {
r.internalServeHTTP(rw, req)
}
}
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
scheme := "http"
if r.SecureMode {
scheme = "https"
if r.Port == 0 {
r.Port = 443
}
} else {
if r.Port == 0 {
r.Port = 80
}
}
p := r.Path
if !r.Abs {
p = path.Join(r.Path, req.URL.Path)
}
if p == "" {
p = "/"
}
buf := new(bytes.Buffer)
if req.Body != nil {
_, _ = io.Copy(buf, req.Body)
}
u := &url.URL{
Scheme: scheme,
Host: r.FullHost(),
Path: p,
RawQuery: req.URL.RawQuery,
}
req2, err := http.NewRequest(req.Method, u.String(), buf)
if err != nil {
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
utils.RespondHttpStatus(rw, http.StatusBadGateway)
return
}
for k, v := range req.Header {
if k == "Host" {
continue
}
req2.Header[k] = v
}
if r.ForwardHost {
req2.Host = req.Host
}
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
}

61
utils/domain-utils.go Normal file
View File

@ -0,0 +1,61 @@
package utils
import (
"fmt"
"strings"
)
func SplitDomainPort(host string, defaultPort uint16) (domain string, port uint16, ok bool) {
a := strings.SplitN(host, ":", 2)
switch len(a) {
case 2:
domain = a[0]
_, err := fmt.Sscanf(a[1], "%d", &port)
ok = err == nil
case 1:
domain = a[0]
port = defaultPort
ok = true
}
return
}
func GetDomainWithoutPort(domain string) (string, bool) {
a := strings.SplitN(domain, ":", 2)
if len(a) == 2 {
return a[0], true
}
if len(a) == 0 {
return "", false
}
return a[0], true
}
func ReplaceSubdomainWithWildcard(domain string) (string, bool) {
a, b := GetBaseDomain(domain)
return "*." + a, b
}
func GetBaseDomain(domain string) (string, bool) {
a := strings.SplitN(domain, ".", 2)
l := len(a)
if l == 2 {
return a[1], true
}
if l == 1 {
return a[0], true
}
return "", false
}
func GetTopFqdn(domain string) (string, bool) {
a := strings.Split(domain, ".")
l := len(a)
if l >= 2 {
return strings.Join(a[l-2:], "."), true
}
if l == 1 {
return domain, true
}
return "", false
}

View File

@ -0,0 +1,58 @@
package utils
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestSplitDomainPort(t *testing.T) {
domain, port, ok := SplitDomainPort("www.example.com:5612", 443)
assert.True(t, ok, "Output should be true")
assert.Equal(t, "www.example.com", domain)
assert.Equal(t, uint16(5612), port)
domain, port, ok = SplitDomainPort("example.com", 443)
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
assert.Equal(t, uint16(443), port)
}
func TestDomainWithoutPort(t *testing.T) {
domain, ok := GetDomainWithoutPort("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "www.example.com", domain)
domain, ok = GetDomainWithoutPort("example.com:443")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
}
func TestReplaceSubdomainWithWildcard(t *testing.T) {
domain, ok := ReplaceSubdomainWithWildcard("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "*.example.com", domain)
domain, ok = ReplaceSubdomainWithWildcard("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "*.example.com:5612", domain)
}
func TestGetBaseDomain(t *testing.T) {
domain, ok := GetBaseDomain("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetBaseDomain("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com:5612", domain)
}
func TestGetTopFqdn(t *testing.T) {
domain, ok := GetTopFqdn("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetTopFqdn("www.www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
}

13
utils/multi-compilable.go Normal file
View File

@ -0,0 +1,13 @@
package utils
type Compilable interface {
Compile()
}
type MultiCompilable []Compilable
func (m MultiCompilable) Compile() {
for _, i := range m {
i.Compile()
}
}

10
utils/response.go Normal file
View File

@ -0,0 +1,10 @@
package utils
import (
"fmt"
"net/http"
)
func RespondHttpStatus(rw http.ResponseWriter, status int) {
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
}

33
utils/server-utils.go Normal file
View File

@ -0,0 +1,33 @@
package utils
import (
"log"
"net/http"
"strings"
)
func logHttpServerError(prefix string, err error) {
if err != nil {
if err == http.ErrServerClosed {
log.Printf("[%s] The http server shutdown successfully\n", prefix)
} else {
log.Printf("[%s] Error trying to host the http server: %s\n", prefix, err.Error())
}
}
}
func RunBackgroundHttp(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServe())
}
func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
}
func GetBearer(req *http.Request) string {
a := req.Header.Get("Authorization")
if t, ok := strings.CutPrefix(a, "Bearer "); ok {
return t
}
return ""
}