Start writing main function

This commit is contained in:
Melon 2023-04-21 03:21:46 +01:00
parent d4aed095ec
commit 0e42e54f08
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
18 changed files with 495 additions and 11 deletions

7
.idea/sqldialects.xml generated Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/cmd/violet/init.sql" dialect="GenericSQL" />
<file url="PROJECT" dialect="SQLite" />
</component>
</project>

View File

@ -22,7 +22,7 @@ func benchRequest(b *testing.B, router http.Handler, r *http.Request) {
}
func BenchmarkVioletRouter(b *testing.B) {
r := router.New()
r := router.New(nil)
r.AddRedirect("*.example.com", "", target.Redirect{
Pre: true,
Host: "example.com",

1
cmd/violet/init.sql Normal file
View File

@ -0,0 +1 @@
create table acme_challenge (id integer not null primary key, key varchar, value varchar);

View File

@ -1,5 +1,59 @@
package violet
package main
import (
"database/sql"
_ "embed"
"errors"
"flag"
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers"
"github.com/MrMelon54/violet/utils"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
)
//go:embed init.sql
var initSql string
var (
databasePath = flag.String("db", "", "/path/to/database.sqlite")
certPath = flag.String("cert", "", "/path/to/certificates")
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
)
func main() {
log.Println("[Violet] Starting...")
_, err := os.Stat(*certPath)
if errors.Is(err, os.ErrNotExist) {
log.Fatalf("[Violet] Certificate path '%s' does not exists", *certPath)
}
_, err = os.Stat(*databasePath)
dbExists := !errors.Is(err, os.ErrNotExist)
db, err := sql.Open("sqlite3", *databasePath)
if err != nil {
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
}
if !dbExists {
log.Println("[Violet] Creating new database and running init.sql")
_, err = db.Exec(initSql)
if err != nil {
log.Fatalf("[Violet] Failed to run init.sql")
}
}
allowedDomains := domains.New()
reverseProxy := proxy.CreateHybridReverseProxy()
r := router.New(reverseProxy)
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{})
servers.NewHttpServer(*httpListen, 0, allowedDomains)
}

36
domains/domains.go Normal file
View File

@ -0,0 +1,36 @@
package domains
import (
"github.com/MrMelon54/violet/utils"
"strings"
"sync"
)
type Domains struct {
s *sync.RWMutex
m map[string]struct{}
}
func New() *Domains {
return &Domains{
s: &sync.RWMutex{},
m: make(map[string]struct{}),
}
}
func (d *Domains) IsValid(host string) bool {
domain, ok := utils.GetDomainWithoutPort(host)
if !ok {
return false
}
d.s.RLock()
defer d.s.RUnlock()
n := strings.Split(domain, ".")
for i := 0; i < len(n); i++ {
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
return true
}
}
return false
}

7
go.mod
View File

@ -3,13 +3,20 @@ module github.com/MrMelon54/violet
go 1.20
require (
code.mrmelon54.com/melon/summer-utils v0.0.3
github.com/MrMelon54/trie v0.0.2
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16
github.com/mrmelon54/mjwt v0.0.1
github.com/rs/cors v1.9.0
github.com/stretchr/testify v1.8.2
golang.org/x/net v0.9.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

14
go.sum
View File

@ -1,10 +1,24 @@
code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q=
code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw=
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mrmelon54/mjwt v0.0.1 h1:XgyWviTmgsbMiKXjxo+Jp/QSf7FF7/omkvrUag8/P5U=
github.com/mrmelon54/mjwt v0.0.1/go.mod h1:M+kZ6t9EArEQ2/CGjfgyNhAo542ot+S7gw5uJCK11Ms=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

View File

@ -16,13 +16,14 @@ type Router struct {
proxy *httputil.ReverseProxy
}
func New() *Router {
func New(proxy *httputil.ReverseProxy) *Router {
return &Router{
route: make(map[string]*trie.Trie[target.Route]),
redirect: make(map[string]*trie.Trie[target.Redirect]),
notFound: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = fmt.Fprintf(rw, "%d %s\n", http.StatusNotFound, http.StatusText(http.StatusNotFound))
}),
proxy: proxy,
}
}
@ -45,10 +46,11 @@ func (r *Router) hostRedirect(host string) *trie.Trie[target.Redirect] {
}
func (r *Router) AddService(host string, t target.Route) {
r.AddRoute(host, "", t)
r.AddRoute(host, "/", t)
}
func (r *Router) AddRoute(host string, path string, t target.Route) {
t.Proxy = r.proxy
r.hostRoute(host).PutString(path, t)
}

View File

@ -113,7 +113,7 @@ func assertHttpRedirect(t *testing.T, r *Router, code int, target, method, start
func TestRouter_AddRedirect(t *testing.T) {
for _, i := range redirectTests {
r := New()
r := New(nil)
dst := i.dst
dst.Host = "example.com"
dst.Code = http.StatusFound

59
servers/api.go Normal file
View File

@ -0,0 +1,59 @@
package servers
import (
"code.mrmelon54.com/melon/summer-utils/claims/auth"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"github.com/mrmelon54/mjwt"
"log"
"net/http"
"time"
)
// NewApiServer creates and runs a *http.Server containing all the API endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects from the configuration files
func NewApiServer(listen string, verify mjwt.Provider, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New()
// Endpoint `/compile` reloads all domains, routes and redirects from the configuration files
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// Get bearer token
bearer := utils.GetBearer(req)
if bearer == "" {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Token must have `violet:compile` perm
if !b.Claims.Perms.Has("violet:compile") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Trigger the compile action
compileTarget.Compile()
rw.WriteHeader(http.StatusAccepted)
})
// Create and run http server
s := &http.Server{
Addr: listen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute,
MaxHeaderBytes: 2500,
}
log.Printf("[API] Starting API server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttp("API", s)
return s
}

26
servers/http.go Normal file
View File

@ -0,0 +1,26 @@
package servers
import (
"fmt"
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
)
func NewHttpServer(listen string, httpsPort uint16, domainCheck *domains.Domains) *http.Server {
r := httprouter.New()
r.GET("/.well-known/acme-challenge/{token}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if hostname, ok := utils.GetDomainWithoutPort(req.Host); ok {
if !domainCheck.IsValid(req.Host) {
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
return
}
if tokenValue := params.ByName("token"); tokenValue != "" {
rw.WriteHeader(http.StatusOK)
return
}
}
rw.WriteHeader(http.StatusNotFound)
})
}

1
servers/https.go Normal file
View File

@ -0,0 +1 @@
package servers

View File

@ -1,17 +1,119 @@
package target
import (
"bytes"
"fmt"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/utils"
"github.com/rs/cors"
"io"
"log"
"net/http"
"net/url"
"path"
)
var serveApiCors = cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowedMethods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodConnect,
http.MethodOptions,
http.MethodTrace,
},
AllowCredentials: true,
})
type Route struct {
Pre bool
Host string
Port int
Path string
Abs bool
Pre bool
Host string
Port int
Path string
Abs bool
Cors bool
SecureMode bool
ForwardHost bool
IgnoreCert bool
Headers http.Header
Proxy http.Handler
}
func (r Route) IsIgnoreCert() bool { return r.IgnoreCert }
func (r Route) UpdateHeaders(header http.Header) {
for k, v := range r.Headers {
header[k] = v
}
}
func (r Route) FullHost() string {
if r.Port == 0 {
return r.Host
}
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
func (r Route) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// pass
if r.Cors {
serveApiCors.Handler(http.HandlerFunc(r.internalServeHTTP)).ServeHTTP(rw, req)
} else {
r.internalServeHTTP(rw, req)
}
}
func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
scheme := "http"
if r.SecureMode {
scheme = "https"
if r.Port == 0 {
r.Port = 443
}
} else {
if r.Port == 0 {
r.Port = 80
}
}
p := r.Path
if !r.Abs {
p = path.Join(r.Path, req.URL.Path)
}
if p == "" {
p = "/"
}
buf := new(bytes.Buffer)
if req.Body != nil {
_, _ = io.Copy(buf, req.Body)
}
u := &url.URL{
Scheme: scheme,
Host: r.FullHost(),
Path: p,
RawQuery: req.URL.RawQuery,
}
req2, err := http.NewRequest(req.Method, u.String(), buf)
if err != nil {
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
utils.RespondHttpStatus(rw, http.StatusBadGateway)
return
}
for k, v := range req.Header {
if k == "Host" {
continue
}
req2.Header[k] = v
}
if r.ForwardHost {
req2.Host = req.Host
}
r.Proxy.ServeHTTP(rw, proxy.SetReverseProxyHost(req2, r))
}

61
utils/domain-utils.go Normal file
View File

@ -0,0 +1,61 @@
package utils
import (
"fmt"
"strings"
)
func SplitDomainPort(host string, defaultPort uint16) (domain string, port uint16, ok bool) {
a := strings.SplitN(host, ":", 2)
switch len(a) {
case 2:
domain = a[0]
_, err := fmt.Sscanf(a[1], "%d", &port)
ok = err == nil
case 1:
domain = a[0]
port = defaultPort
ok = true
}
return
}
func GetDomainWithoutPort(domain string) (string, bool) {
a := strings.SplitN(domain, ":", 2)
if len(a) == 2 {
return a[0], true
}
if len(a) == 0 {
return "", false
}
return a[0], true
}
func ReplaceSubdomainWithWildcard(domain string) (string, bool) {
a, b := GetBaseDomain(domain)
return "*." + a, b
}
func GetBaseDomain(domain string) (string, bool) {
a := strings.SplitN(domain, ".", 2)
l := len(a)
if l == 2 {
return a[1], true
}
if l == 1 {
return a[0], true
}
return "", false
}
func GetTopFqdn(domain string) (string, bool) {
a := strings.Split(domain, ".")
l := len(a)
if l >= 2 {
return strings.Join(a[l-2:], "."), true
}
if l == 1 {
return domain, true
}
return "", false
}

View File

@ -0,0 +1,58 @@
package utils
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestSplitDomainPort(t *testing.T) {
domain, port, ok := SplitDomainPort("www.example.com:5612", 443)
assert.True(t, ok, "Output should be true")
assert.Equal(t, "www.example.com", domain)
assert.Equal(t, uint16(5612), port)
domain, port, ok = SplitDomainPort("example.com", 443)
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
assert.Equal(t, uint16(443), port)
}
func TestDomainWithoutPort(t *testing.T) {
domain, ok := GetDomainWithoutPort("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "www.example.com", domain)
domain, ok = GetDomainWithoutPort("example.com:443")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
}
func TestReplaceSubdomainWithWildcard(t *testing.T) {
domain, ok := ReplaceSubdomainWithWildcard("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "*.example.com", domain)
domain, ok = ReplaceSubdomainWithWildcard("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "*.example.com:5612", domain)
}
func TestGetBaseDomain(t *testing.T) {
domain, ok := GetBaseDomain("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetBaseDomain("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com:5612", domain)
}
func TestGetTopFqdn(t *testing.T) {
domain, ok := GetTopFqdn("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetTopFqdn("www.www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
}

13
utils/multi-compilable.go Normal file
View File

@ -0,0 +1,13 @@
package utils
type Compilable interface {
Compile()
}
type MultiCompilable []Compilable
func (m MultiCompilable) Compile() {
for _, i := range m {
i.Compile()
}
}

10
utils/response.go Normal file
View File

@ -0,0 +1,10 @@
package utils
import (
"fmt"
"net/http"
)
func RespondHttpStatus(rw http.ResponseWriter, status int) {
http.Error(rw, fmt.Sprintf("%d %s\n", status, http.StatusText(status)), status)
}

33
utils/server-utils.go Normal file
View File

@ -0,0 +1,33 @@
package utils
import (
"log"
"net/http"
"strings"
)
func logHttpServerError(prefix string, err error) {
if err != nil {
if err == http.ErrServerClosed {
log.Printf("[%s] The http server shutdown successfully\n", prefix)
} else {
log.Printf("[%s] Error trying to host the http server: %s\n", prefix, err.Error())
}
}
}
func RunBackgroundHttp(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServe())
}
func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
}
func GetBearer(req *http.Request) string {
a := req.Header.Get("Authorization")
if t, ok := strings.CutPrefix(a, "Bearer "); ok {
return t
}
return ""
}