Compare commits

...

31 Commits

Author SHA1 Message Date
d3d6782b22
Merge branch 'tableflip' 2024-11-03 23:12:07 +00:00
0f095056d4
Modify routing error messages 2024-10-11 17:07:22 +01:00
aa77dccaaf
Log route source and internal url on error 2024-08-28 19:53:56 +01:00
ecee594219
Less logging from hybrid transport 2024-08-28 19:49:30 +01:00
d7b7721378
Correct logged listening addresses 2024-08-17 13:15:15 +01:00
3e86b91ec3
Add support for tableflip 2024-08-17 12:29:50 +01:00
f442409ebf
Forgot to make the map 2024-08-06 00:18:36 +01:00
8aa82303ce
Add filterWebsocketHeaders function 2024-08-06 00:00:54 +01:00
1f4f4414d5
Transition to new logger 2024-05-13 19:33:33 +01:00
a8db73d957
Update dependencies 2024-05-13 16:48:35 +01:00
1181fde508
Update dependencies 2024-04-20 16:17:32 +01:00
900203b560
Update go version 2024-04-15 16:21:47 +01:00
69bce2d12d
Initial support for sqlc and migrations 2024-03-08 16:05:39 +00:00
a13db89c44
Add host to metric middlewares 2024-02-21 11:44:51 +00:00
e901a73129
Add metrics 2024-02-16 01:41:42 +00:00
333394cf89
Patch to not use * for CORS auth header 2024-02-14 19:30:21 +00:00
f8dde8eebe
Try to reload certificates every 2 hours 2024-01-14 14:00:54 +00:00
822c7b570a
Add suggested TLSv1.2 config 2023-12-16 01:03:22 +00:00
bc6e98db8c
Reformat the tls config 2023-12-06 08:37:35 +00:00
2cce26429b
Update go version 2023-12-04 16:40:02 +00:00
5643f05aa0
Update dependencies 2023-12-04 16:38:31 +00:00
fc2f3d5b7b
Replace bodged GetTopFqdn with eTLD+1 2023-11-03 08:24:17 +00:00
37b0617e78
Add HSTS header 2023-11-03 08:09:29 +00:00
1194717a32
Detect no subdomains in GetTopFqdn 2023-10-29 13:15:25 +00:00
11b989b50c
Remove extra logging line 2023-10-29 12:56:53 +00:00
30bcea40b8
New mjwt library path and rewrite to allow changing active variable and properly deleting routes 2023-10-29 12:55:19 +00:00
69670e068b
Add description field to routes and redirects 2023-10-28 22:20:04 +01:00
c91f1dd2fc
I forgot the wildcard in getDomainOwnershipClaims 2023-10-27 12:54:16 +01:00
754fd2d396
Add tests and remove full route and redirect performance code 2023-10-27 12:51:12 +01:00
3834787f8f
Better route and redirect searching 2023-10-27 11:55:18 +01:00
52547234b0
Add domain specific get request 2023-10-27 09:16:52 +01:00
65 changed files with 1606 additions and 594 deletions

View File

@ -4,7 +4,7 @@ jobs:
test:
strategy:
matrix:
go-version: [1.20.x]
go-version: [1.22.x]
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
*.sqlite
*.local
.idea/
.data

8
.idea/.gitignore generated vendored
View File

@ -1,8 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@ -1,5 +0,0 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

12
.idea/dataSources.xml generated
View File

@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="identifier.sqlite" uuid="5b42d21a-92a8-43d0-8651-c1555b91060c">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

7
.idea/discord.xml generated
View File

@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DiscordProjectSettings">
<option name="show" value="PROJECT_FILES" />
<option name="description" value="" />
</component>
</project>

6
.idea/misc.xml generated
View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
</project>

8
.idea/modules.xml generated
View File

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/violet.iml" filepath="$PROJECT_DIR$/.idea/violet.iml" />
</modules>
</component>
</project>

6
.idea/sqldialects.xml generated
View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="PROJECT" dialect="SQLite" />
</component>
</project>

6
.idea/vcs.xml generated
View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

9
.idea/violet.iml generated
View File

@ -1,9 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -4,11 +4,11 @@ import (
"crypto/tls"
"crypto/x509/pkix"
"fmt"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/certgen"
"github.com/MrMelon54/rescheduler"
"github.com/mrmelon54/certgen"
"github.com/mrmelon54/rescheduler"
"io/fs"
"log"
"math/big"
"os"
"strings"
@ -17,6 +17,8 @@ import (
"time"
)
var Logger = logger.Logger.WithPrefix("Violet Certs")
// Certs is the certificate loader and management system.
type Certs struct {
cDir fs.FS
@ -27,6 +29,8 @@ type Certs struct {
ca *certgen.CertGen
sn atomic.Int64
r *rescheduler.Rescheduler
t *time.Ticker
ts chan struct{}
}
// New creates a new cert list
@ -37,15 +41,26 @@ func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
ss: selfCert,
s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate),
ts: make(chan struct{}, 1),
}
// the rescheduler isn't even used in self cert mode so why initialise it
if !selfCert {
// the rescheduler isn't even used in self cert mode so why initialise it
c.r = rescheduler.NewRescheduler(c.threadCompile)
}
// in self-signed mode generate a CA certificate to sign other certificates
if c.ss {
c.t = time.NewTicker(2 * time.Hour)
go func() {
for {
select {
case <-c.t.C:
c.Compile()
case <-c.ts:
return
}
}
}()
} else {
// in self-signed mode generate a CA certificate to sign other certificates
ca, err := certgen.MakeCaTls(4096, pkix.Name{
Country: []string{"GB"},
Organization: []string{"Violet"},
@ -56,7 +71,7 @@ func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
return now.AddDate(10, 0, 0)
})
if err != nil {
log.Fatalln("Failed to generate CA cert for self-signed mode:", err)
logger.Logger.Fatal("Failed to generate CA cert for self-signed mode", "err", err)
}
c.ca = ca
}
@ -118,6 +133,13 @@ func (c *Certs) Compile() {
c.r.Run()
}
func (c *Certs) Stop() {
if c.t != nil {
c.t.Stop()
}
close(c.ts)
}
func (c *Certs) threadCompile() {
// new map
certMap := make(map[string]*tls.Certificate)
@ -125,7 +147,7 @@ func (c *Certs) threadCompile() {
// compile map and check errors
err := c.internalCompile(certMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
Logger.Infof("Compile failed: %s\n", err)
return
}
@ -148,7 +170,7 @@ func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
return fmt.Errorf("failed to read cert dir: %w", err)
}
log.Printf("[Certs] Compiling lookup table for %d certificates\n", len(files))
Logger.Infof("Compiling lookup table for %d certificates\n", len(files))
// find and parse certs
for _, i := range files {

View File

@ -3,7 +3,7 @@ package certs
import (
"crypto/x509/pkix"
"fmt"
"github.com/MrMelon54/certgen"
"github.com/mrmelon54/certgen"
"github.com/stretchr/testify/assert"
"math/big"
"testing"
@ -16,7 +16,7 @@ func TestCertsNew_Lookup(t *testing.T) {
// type to test that certificate files can be found and read correctly. This
// uses a MapFS for performance during tests.
ca, err := certgen.MakeCaTls(4096, pkix.Name{
ca, err := certgen.MakeCaTls(2048, pkix.Name{
Country: []string{"GB"},
Organization: []string{"Violet"},
OrganizationalUnit: []string{"Development"},
@ -29,7 +29,7 @@ func TestCertsNew_Lookup(t *testing.T) {
domain := "example.com"
sn := int64(1)
serverTls, err := certgen.MakeServerTls(ca, 4096, pkix.Name{
serverTls, err := certgen.MakeServerTls(ca, 2048, pkix.Name{
Country: []string{"GB"},
Organization: []string{domain},
OrganizationalUnit: []string{domain},
@ -63,6 +63,10 @@ func TestCertsNew_Lookup(t *testing.T) {
}
func TestCertsNew_SelfSigned(t *testing.T) {
if testing.Short() {
return
}
certs := New(nil, nil, true)
cc := certs.GetCertForDomain("example.com")
leaf := certgen.TlsLeaf(cc)

View File

@ -2,13 +2,15 @@ package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"github.com/1f349/mjwt"
"github.com/1f349/violet"
"github.com/1f349/violet/certs"
"github.com/1f349/violet/domains"
errorPages "github.com/1f349/violet/error-pages"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
@ -16,60 +18,64 @@ import (
"github.com/1f349/violet/servers/api"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/exit-reload"
"github.com/MrMelon54/mjwt"
"github.com/charmbracelet/log"
"github.com/cloudflare/tableflip"
"github.com/google/subcommands"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"io/fs"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime/pprof"
"syscall"
"time"
)
type serveCmd struct {
configPath string
cpuprofile string
debugLog bool
pidFile string
}
func (s *serveCmd) Name() string { return "serve" }
func (s *serveCmd) Synopsis() string { return "Serve reverse proxy server" }
func (s *serveCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.configPath, "conf", "", "/path/to/config.json : path to the config file")
f.StringVar(&s.cpuprofile, "cpuprofile", "", "write cpu profile to file")
f.BoolVar(&s.debugLog, "debug", false, "enable debug logging")
f.StringVar(&s.pidFile, "pid-file", "", "path to pid file")
}
func (s *serveCmd) Usage() string {
return `serve [-conf <config file>] [-cpuprofile <profile file>]
return `serve [-conf <config file>] [-debug] [-pid-file <pid file>]
Serve reverse proxy server using information from config file
`
}
func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
log.Println("[Violet] Starting...")
// Enable cpu profiling
if s.cpuprofile != "" {
f, err := os.Create(s.cpuprofile)
if err != nil {
log.Fatal(err)
}
log.Printf("[Violet] CPU profiling enabled, writing to '%s'\n", s.cpuprofile)
_ = pprof.StartCPUProfile(f)
defer pprof.StopCPUProfile()
if s.debugLog {
logger.Logger.SetLevel(log.DebugLevel)
}
logger.Logger.Info("Starting...")
upg, err := tableflip.New(tableflip.Options{
PIDFile: s.pidFile,
})
if err != nil {
panic(err)
}
defer upg.Stop()
if s.configPath == "" {
log.Println("[Violet] Error: config flag is missing")
logger.Logger.Info("Error: config flag is missing")
return subcommands.ExitUsageError
}
openConf, err := os.Open(s.configPath)
if err != nil {
if os.IsNotExist(err) {
log.Println("[Violet] Error: missing config file")
logger.Logger.Info("Error: missing config file")
} else {
log.Println("[Violet] Error: open config file: ", err)
logger.Logger.Info("Error: open config file: ", err)
}
return subcommands.ExitFailure
}
@ -77,124 +83,167 @@ func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
var config startUpConfig
err = json.NewDecoder(openConf).Decode(&config)
if err != nil {
log.Println("[Violet] Error: invalid config file: ", err)
logger.Logger.Info("Error: invalid config file: ", err)
return subcommands.ExitFailure
}
// working directory is the parent of the config file
wd := filepath.Dir(s.configPath)
normalLoad(config, wd)
return subcommands.ExitSuccess
}
func normalLoad(startUp startUpConfig, wd string) {
// the cert and key paths are useless in self-signed mode
if !startUp.SelfSigned {
if !config.SelfSigned {
// create path to cert dir
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
if err != nil {
log.Fatal("[Violet] Failed to create certificate path")
logger.Logger.Fatal("Failed to create certificate path")
}
// create path to key dir
err = os.MkdirAll(filepath.Join(wd, "keys"), os.ModePerm)
if err != nil {
log.Fatal("[Violet] Failed to create certificate key path")
logger.Logger.Fatal("Failed to create certificate key path")
}
}
// errorPageDir stores an FS interface for accessing the error page directory
var errorPageDir fs.FS
if startUp.ErrorPagePath != "" {
errorPageDir = os.DirFS(startUp.ErrorPagePath)
err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm)
if config.ErrorPagePath != "" {
errorPageDir = os.DirFS(config.ErrorPagePath)
err := os.MkdirAll(config.ErrorPagePath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create error page path '%s'", startUp.ErrorPagePath)
logger.Logger.Fatal("Failed to create error page", "path", config.ErrorPagePath)
}
}
// load the MJWT RSA public key from a pem encoded file
mJwtVerify, err := mjwt.NewMJwtVerifierFromFile(filepath.Join(wd, "signer.public.pem"))
if err != nil {
log.Fatalf("[Violet] Failed to load MJWT verifier public key from file '%s': %s", filepath.Join(wd, "signer.public.pem"), err)
logger.Logger.Fatal("Failed to load MJWT verifier public key", "file", filepath.Join(wd, "signer.public.pem"), "err", err)
}
// open sqlite database
db, err := sql.Open("sqlite3", filepath.Join(wd, "violet.db.sqlite"))
db, err := violet.InitDB(filepath.Join(wd, "violet.db.sqlite"))
if err != nil {
log.Fatal("[Violet] Failed to open database")
logger.Logger.Fatal("Failed to open database", "err", err)
}
certDir := os.DirFS(filepath.Join(wd, "certs"))
keyDir := os.DirFS(filepath.Join(wd, "keys"))
// setup registry for metrics
promRegistry := prometheus.NewRegistry()
promRegistry.MustRegister(
collectors.NewGoCollector(),
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
)
ws := websocket.NewServer()
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, config.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy
dynamicFavicons := favicons.New(db, config.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
// struct containing config for the http servers
srvConf := &conf.Conf{
ApiListen: startUp.Listen.Api,
HttpListen: startUp.Listen.Http,
HttpsListen: startUp.Listen.Https,
RateLimit: startUp.RateLimit,
DB: db,
Domains: allowedDomains,
Acme: acmeChallenges,
Certs: allowedCerts,
Favicons: dynamicFavicons,
Signer: mJwtVerify,
ErrorPages: dynamicErrorPages,
Router: dynamicRouter,
RateLimit: config.RateLimit,
DB: db,
Domains: allowedDomains,
Acme: acmeChallenges,
Certs: allowedCerts,
Favicons: dynamicFavicons,
Signer: mJwtVerify,
ErrorPages: dynamicErrorPages,
Router: dynamicRouter,
}
// create the compilable list and run a first time compile
allCompilables := utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}
allCompilables.Compile()
var srvApi, srvHttp, srvHttps *http.Server
if srvConf.ApiListen != "" {
srvApi = api.NewApiServer(srvConf, allCompilables)
srvApi.SetKeepAlivesEnabled(false)
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
go utils.RunBackgroundHttp("API", srvApi)
}
if srvConf.HttpListen != "" {
srvHttp = servers.NewHttpServer(srvConf)
srvHttp.SetKeepAlivesEnabled(false)
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", srvHttp.Addr)
go utils.RunBackgroundHttp("HTTP", srvHttp)
}
if srvConf.HttpsListen != "" {
srvHttps = servers.NewHttpsServer(srvConf)
srvHttps.SetKeepAlivesEnabled(false)
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", srvHttps.Addr)
go utils.RunBackgroundHttps("HTTPS", srvHttps)
_, httpsPort, ok := utils.SplitDomainPort(config.Listen.Https, 443)
if !ok {
httpsPort = 443
}
var srvApi, srvHttp, srvHttps *http.Server
if config.Listen.Api != "" {
// Listen must be called before Ready
lnApi, err := upg.Listen("tcp", config.Listen.Api)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvApi = api.NewApiServer(srvConf, allCompilables, promRegistry)
srvApi.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "API")
l.Info("Starting server", "addr", config.Listen.Api)
go utils.RunBackgroundHttp(l, srvApi, lnApi)
}
if config.Listen.Http != "" {
// Listen must be called before Ready
lnHttp, err := upg.Listen("tcp", config.Listen.Http)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvHttp = servers.NewHttpServer(uint16(httpsPort), srvConf, promRegistry)
srvHttp.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "HTTP")
l.Info("Starting server", "addr", config.Listen.Http)
go utils.RunBackgroundHttp(l, srvHttp, lnHttp)
}
if config.Listen.Https != "" {
// Listen must be called before Ready
lnHttps, err := upg.Listen("tcp", config.Listen.Https)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvHttps = servers.NewHttpsServer(srvConf, promRegistry)
srvHttps.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "HTTPS")
l.Info("Starting server", "addr", config.Listen.Https)
go utils.RunBackgroundHttps(l, srvHttps, lnHttps)
}
// Do an upgrade on SIGHUP
go func() {
log.Println(http.ListenAndServe("localhost:6600", nil))
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGHUP)
for range sig {
err := upg.Upgrade()
if err != nil {
logger.Logger.Error("Failed upgrade", "err", err)
}
}
}()
exit_reload.ExitReload("Violet", func() {
allCompilables.Compile()
}, func() {
// close websockets first
ws.Shutdown()
logger.Logger.Info("Ready")
if err := upg.Ready(); err != nil {
panic(err)
}
<-upg.Exit()
// close http servers
if srvApi != nil {
_ = srvApi.Close()
}
if srvHttp != nil {
_ = srvHttp.Close()
}
if srvHttps != nil {
_ = srvHttps.Close()
}
time.AfterFunc(30*time.Second, func() {
logger.Logger.Warn("Graceful shutdown timed out")
os.Exit(1)
})
// stop updating certificates
allowedCerts.Stop()
// close websockets first
ws.Shutdown()
// close http servers
if srvApi != nil {
_ = srvApi.Close()
}
if srvHttp != nil {
_ = srvHttp.Close()
}
if srvHttps != nil {
_ = srvHttps.Close()
}
return subcommands.ExitSuccess
}

View File

@ -2,18 +2,18 @@ package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"github.com/1f349/violet"
"github.com/1f349/violet/domains"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/target"
"github.com/AlecAivazis/survey/v2"
"github.com/google/subcommands"
"log"
"net"
"net/http"
"net/url"
@ -42,7 +42,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
// get absolute path to specify files
wdAbs, err := filepath.Abs(s.wdPath)
if err != nil {
fmt.Println("[Violet] Failed to get full directory path: ", err)
fmt.Println("Failed to get full directory path: ", err)
return subcommands.ExitFailure
}
@ -50,11 +50,11 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
createFile := false
err = survey.AskOne(&survey.Confirm{Message: fmt.Sprintf("Create Violet config files in this directory: '%s'?", wdAbs)}, &createFile)
if err != nil {
fmt.Println("[Violet] Error: ", err)
fmt.Println("Error: ", err)
return subcommands.ExitFailure
}
if !createFile {
fmt.Println("[Violet] Goodbye")
fmt.Println("Goodbye")
return subcommands.ExitSuccess
}
@ -111,7 +111,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
},
}, &answers)
if err != nil {
fmt.Println("[Violet] Error: ", err)
fmt.Println("Error: ", err)
return subcommands.ExitFailure
}
@ -142,14 +142,14 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
RateLimit: answers.RateLimit,
})
if err != nil {
fmt.Println("[Violet] Failed to write config file: ", err)
fmt.Println("Failed to write config file: ", err)
return subcommands.ExitFailure
}
// open sqlite database
db, err := sql.Open("sqlite3", databaseFile)
db, err := violet.InitDB(databaseFile)
if err != nil {
log.Fatalf("[Violet] Failed to open database '%s'...", databaseFile)
logger.Logger.Fatal("Failed to open database", "err", err)
}
// domain manager to add a domain, no need to compile here as the program needs
@ -168,33 +168,36 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
return nil
}))
if err != nil {
fmt.Println("[Violet] Error: ", err)
fmt.Println("Error: ", err)
return subcommands.ExitFailure
}
// parse the api url
apiUrl, err := url.Parse(answers.ApiUrl)
if err != nil {
fmt.Println("[Violet] Failed to parse API URL: ", err)
fmt.Println("Failed to parse API URL: ", err)
return subcommands.ExitFailure
}
// add with the route manager, no need to compile as this will run when opened
// with the serve subcommand
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}, &websocket.Server{}))
err = routeManager.InsertRoute(target.Route{
Src: path.Join(apiUrl.Host, apiUrl.Path),
Dst: answers.ApiListen,
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
err = routeManager.InsertRoute(target.RouteWithActive{
Route: target.Route{
Src: path.Join(apiUrl.Host, apiUrl.Path),
Dst: answers.ApiListen,
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
},
Active: true,
})
if err != nil {
fmt.Println("[Violet] Failed to insert api route into database: ", err)
fmt.Println("Failed to insert api route into database: ", err)
return subcommands.ExitFailure
}
}
fmt.Println("[Violet] Setup complete")
fmt.Printf("[Violet] Run the reverse proxy with `violet serve -conf %s`\n", confFile)
fmt.Println("Setup complete")
fmt.Printf("Run the reverse proxy with `violet serve -conf %s`\n", confFile)
return subcommands.ExitSuccess
}

31
database/db.go Normal file
View File

@ -0,0 +1,31 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
package database
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

68
database/domain.sql.go Normal file
View File

@ -0,0 +1,68 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: domain.sql
package database
import (
"context"
)
const addDomain = `-- name: AddDomain :exec
INSERT OR
REPLACE
INTO domains (domain, active)
VALUES (?, ?)
`
type AddDomainParams struct {
Domain string `json:"domain"`
Active bool `json:"active"`
}
func (q *Queries) AddDomain(ctx context.Context, arg AddDomainParams) error {
_, err := q.db.ExecContext(ctx, addDomain, arg.Domain, arg.Active)
return err
}
const deleteDomain = `-- name: DeleteDomain :exec
INSERT OR
REPLACE
INTO domains(domain, active)
VALUES (?, false)
`
func (q *Queries) DeleteDomain(ctx context.Context, domain string) error {
_, err := q.db.ExecContext(ctx, deleteDomain, domain)
return err
}
const getActiveDomains = `-- name: GetActiveDomains :many
SELECT domain
FROM domains
WHERE active = 1
`
func (q *Queries) GetActiveDomains(ctx context.Context) ([]string, error) {
rows, err := q.db.QueryContext(ctx, getActiveDomains)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var domain string
if err := rows.Scan(&domain); err != nil {
return nil, err
}
items = append(items, domain)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

74
database/favicon.sql.go Normal file
View File

@ -0,0 +1,74 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: favicon.sql
package database
import (
"context"
"database/sql"
)
const getFavicons = `-- name: GetFavicons :many
SELECT host, svg, png, ico
FROM favicons
`
type GetFaviconsRow struct {
Host string `json:"host"`
Svg sql.NullString `json:"svg"`
Png sql.NullString `json:"png"`
Ico sql.NullString `json:"ico"`
}
func (q *Queries) GetFavicons(ctx context.Context) ([]GetFaviconsRow, error) {
rows, err := q.db.QueryContext(ctx, getFavicons)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetFaviconsRow
for rows.Next() {
var i GetFaviconsRow
if err := rows.Scan(
&i.Host,
&i.Svg,
&i.Png,
&i.Ico,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateFaviconCache = `-- name: UpdateFaviconCache :exec
INSERT OR
REPLACE INTO favicons (host, svg, png, ico)
VALUES (?, ?, ?, ?)
`
type UpdateFaviconCacheParams struct {
Host string `json:"host"`
Svg sql.NullString `json:"svg"`
Png sql.NullString `json:"png"`
Ico sql.NullString `json:"ico"`
}
func (q *Queries) UpdateFaviconCache(ctx context.Context, arg UpdateFaviconCacheParams) error {
_, err := q.db.ExecContext(ctx, updateFaviconCache,
arg.Host,
arg.Svg,
arg.Png,
arg.Ico,
)
return err
}

View File

@ -0,0 +1,4 @@
DROP TABLE domains;
DROP TABLE favicons;
DROP TABLE routes;
DROP TABLE redirects;

View File

@ -0,0 +1,36 @@
CREATE TABLE IF NOT EXISTS domains
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain TEXT UNIQUE NOT NULL,
active BOOLEAN NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS favicons
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
host VARCHAR NOT NULL,
svg VARCHAR,
png VARCHAR,
ico VARCHAR
);
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE NOT NULL,
destination TEXT NOT NULL,
description TEXT NOT NULL,
flags INTEGER NOT NULL DEFAULT 0,
active BOOLEAN NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE NOT NULL,
destination TEXT NOT NULL,
description TEXT NOT NULL,
flags INTEGER NOT NULL DEFAULT 0,
code INTEGER NOT NULL DEFAULT 0,
active BOOLEAN NOT NULL DEFAULT 1
);

44
database/models.go Normal file
View File

@ -0,0 +1,44 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
package database
import (
"database/sql"
"github.com/1f349/violet/target"
)
type Domain struct {
ID int64 `json:"id"`
Domain string `json:"domain"`
Active bool `json:"active"`
}
type Favicon struct {
ID int64 `json:"id"`
Host string `json:"host"`
Svg sql.NullString `json:"svg"`
Png sql.NullString `json:"png"`
Ico sql.NullString `json:"ico"`
}
type Redirect struct {
ID int64 `json:"id"`
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
Active bool `json:"active"`
}
type Route struct {
ID int64 `json:"id"`
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Active bool `json:"active"`
}

View File

@ -0,0 +1,16 @@
-- name: GetActiveDomains :many
SELECT domain
FROM domains
WHERE active = 1;
-- name: AddDomain :exec
INSERT OR
REPLACE
INTO domains (domain, active)
VALUES (?, ?);
-- name: DeleteDomain :exec
INSERT OR
REPLACE
INTO domains(domain, active)
VALUES (?, false);

View File

@ -0,0 +1,8 @@
-- name: GetFavicons :many
SELECT host, svg, png, ico
FROM favicons;
-- name: UpdateFaviconCache :exec
INSERT OR
REPLACE INTO favicons (host, svg, png, ico)
VALUES (?, ?, ?, ?);

View File

@ -0,0 +1,39 @@
-- name: GetActiveRoutes :many
SELECT source, destination, flags
FROM routes
WHERE active = 1;
-- name: GetActiveRedirects :many
SELECT source, destination, flags, code
FROM redirects
WHERE active = 1;
-- name: GetAllRoutes :many
SELECT source, destination, description, flags, active
FROM routes;
-- name: GetAllRedirects :many
SELECT source, destination, description, flags, code, active
FROM redirects;
-- name: AddRoute :exec
INSERT OR
REPLACE
INTO routes (source, destination, description, flags, active)
VALUES (?, ?, ?, ?, ?);
-- name: AddRedirect :exec
INSERT OR
REPLACE
INTO redirects (source, destination, description, flags, code, active)
VALUES (?, ?, ?, ?, ?, ?);
-- name: RemoveRoute :exec
DELETE
FROM routes
WHERE source = ?;
-- name: RemoveRedirect :exec
DELETE
FROM redirects
WHERE source = ?;

250
database/routing.sql.go Normal file
View File

@ -0,0 +1,250 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: routing.sql
package database
import (
"context"
"github.com/1f349/violet/target"
)
const addRedirect = `-- name: AddRedirect :exec
INSERT OR
REPLACE
INTO redirects (source, destination, description, flags, code, active)
VALUES (?, ?, ?, ?, ?, ?)
`
type AddRedirectParams struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
Active bool `json:"active"`
}
func (q *Queries) AddRedirect(ctx context.Context, arg AddRedirectParams) error {
_, err := q.db.ExecContext(ctx, addRedirect,
arg.Source,
arg.Destination,
arg.Description,
arg.Flags,
arg.Code,
arg.Active,
)
return err
}
const addRoute = `-- name: AddRoute :exec
INSERT OR
REPLACE
INTO routes (source, destination, description, flags, active)
VALUES (?, ?, ?, ?, ?)
`
type AddRouteParams struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Active bool `json:"active"`
}
func (q *Queries) AddRoute(ctx context.Context, arg AddRouteParams) error {
_, err := q.db.ExecContext(ctx, addRoute,
arg.Source,
arg.Destination,
arg.Description,
arg.Flags,
arg.Active,
)
return err
}
const getActiveRedirects = `-- name: GetActiveRedirects :many
SELECT source, destination, flags, code
FROM redirects
WHERE active = 1
`
type GetActiveRedirectsRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
}
func (q *Queries) GetActiveRedirects(ctx context.Context) ([]GetActiveRedirectsRow, error) {
rows, err := q.db.QueryContext(ctx, getActiveRedirects)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetActiveRedirectsRow
for rows.Next() {
var i GetActiveRedirectsRow
if err := rows.Scan(
&i.Source,
&i.Destination,
&i.Flags,
&i.Code,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getActiveRoutes = `-- name: GetActiveRoutes :many
SELECT source, destination, flags
FROM routes
WHERE active = 1
`
type GetActiveRoutesRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Flags target.Flags `json:"flags"`
}
func (q *Queries) GetActiveRoutes(ctx context.Context) ([]GetActiveRoutesRow, error) {
rows, err := q.db.QueryContext(ctx, getActiveRoutes)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetActiveRoutesRow
for rows.Next() {
var i GetActiveRoutesRow
if err := rows.Scan(&i.Source, &i.Destination, &i.Flags); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAllRedirects = `-- name: GetAllRedirects :many
SELECT source, destination, description, flags, code, active
FROM redirects
`
type GetAllRedirectsRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
Active bool `json:"active"`
}
func (q *Queries) GetAllRedirects(ctx context.Context) ([]GetAllRedirectsRow, error) {
rows, err := q.db.QueryContext(ctx, getAllRedirects)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAllRedirectsRow
for rows.Next() {
var i GetAllRedirectsRow
if err := rows.Scan(
&i.Source,
&i.Destination,
&i.Description,
&i.Flags,
&i.Code,
&i.Active,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAllRoutes = `-- name: GetAllRoutes :many
SELECT source, destination, description, flags, active
FROM routes
`
type GetAllRoutesRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Active bool `json:"active"`
}
func (q *Queries) GetAllRoutes(ctx context.Context) ([]GetAllRoutesRow, error) {
rows, err := q.db.QueryContext(ctx, getAllRoutes)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAllRoutesRow
for rows.Next() {
var i GetAllRoutesRow
if err := rows.Scan(
&i.Source,
&i.Destination,
&i.Description,
&i.Flags,
&i.Active,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const removeRedirect = `-- name: RemoveRedirect :exec
DELETE
FROM redirects
WHERE source = ?
`
func (q *Queries) RemoveRedirect(ctx context.Context, source string) error {
_, err := q.db.ExecContext(ctx, removeRedirect, source)
return err
}
const removeRoute = `-- name: RemoveRoute :exec
DELETE
FROM routes
WHERE source = ?
`
func (q *Queries) RemoveRoute(ctx context.Context, source string) error {
_, err := q.db.ExecContext(ctx, removeRoute, source)
return err
}

View File

@ -1,6 +0,0 @@
CREATE TABLE IF NOT EXISTS domains
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain TEXT UNIQUE,
active INTEGER DEFAULT 1
);

View File

@ -1,41 +1,34 @@
package domains
import (
"database/sql"
"context"
_ "embed"
"github.com/1f349/violet/database"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/rescheduler"
"log"
"github.com/mrmelon54/rescheduler"
"strings"
"sync"
)
//go:embed create-table-domains.sql
var createTableDomains string
var Logger = logger.Logger.WithPrefix("Violet Domains")
// Domains is the domain list and management system.
type Domains struct {
db *sql.DB
db *database.Queries
s *sync.RWMutex
m map[string]struct{}
r *rescheduler.Rescheduler
}
// New creates a new domain list
func New(db *sql.DB) *Domains {
func New(db *database.Queries) *Domains {
a := &Domains{
db: db,
s: &sync.RWMutex{},
m: make(map[string]struct{}),
}
a.r = rescheduler.NewRescheduler(a.threadCompile)
// init domains table
_, err := a.db.Exec(createTableDomains)
if err != nil {
log.Printf("[WARN] Failed to generate 'domains' table\n")
return nil
}
return a
}
@ -77,7 +70,7 @@ func (d *Domains) threadCompile() {
// compile map and check errors
err := d.internalCompile(domainMap)
if err != nil {
log.Printf("[Domains] Compile failed: %s\n", err)
Logger.Info("Compile faile", "err", err)
return
}
@ -90,43 +83,39 @@ func (d *Domains) threadCompile() {
// internalCompile is a hidden internal method for querying the database during
// the Compile() method.
func (d *Domains) internalCompile(m map[string]struct{}) error {
log.Println("[Domains] Updating domains from database")
Logger.Info("Updating domains from database")
// sql or something?
rows, err := d.db.Query(`select domain from domains where active = 1`)
rows, err := d.db.GetActiveDomains(context.Background())
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the allowed domain names
for rows.Next() {
var name string
err = rows.Scan(&name)
if err != nil {
return err
}
m[name] = struct{}{}
for _, i := range rows {
m[i] = struct{}{}
}
// check for errors
return rows.Err()
return nil
}
func (d *Domains) Put(domain string, active bool) {
d.s.Lock()
defer d.s.Unlock()
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, active)
err := d.db.AddDomain(context.Background(), database.AddDomainParams{
Domain: domain,
Active: active,
})
if err != nil {
log.Printf("[Violet] Database error: %s\n", err)
logger.Logger.Infof("Database error: %s\n", err)
}
}
func (d *Domains) Delete(domain string) {
d.s.Lock()
defer d.s.Unlock()
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, false)
err := d.db.DeleteDomain(context.Background(), domain)
if err != nil {
log.Printf("[Violet] Database error: %s\n", err)
logger.Logger.Infof("Database error: %s\n", err)
}
}

View File

@ -1,18 +1,20 @@
package domains
import (
"database/sql"
"context"
"github.com/1f349/violet"
"github.com/1f349/violet/database"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDomainsNew(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := violet.InitDB("file:TestDomainsNew?mode=memory&cache=shared")
assert.NoError(t, err)
domains := New(db)
_, err = db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1)
err = db.AddDomain(context.Background(), database.AddDomainParams{Domain: "example.com", Active: true})
assert.NoError(t, err)
domains.Compile()
@ -27,11 +29,11 @@ func TestDomainsNew(t *testing.T) {
func TestDomains_IsValid(t *testing.T) {
// open sqlite database
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := violet.InitDB("file:TestDomains_IsValid?mode=memory&cache=shared")
assert.NoError(t, err)
domains := New(db)
_, err = domains.db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1)
err = db.AddDomain(context.Background(), database.AddDomainParams{Domain: "example.com", Active: true})
assert.NoError(t, err)
domains.s.Lock()

View File

@ -2,9 +2,9 @@ package error_pages
import (
"fmt"
"github.com/MrMelon54/rescheduler"
"github.com/1f349/violet/logger"
"github.com/mrmelon54/rescheduler"
"io/fs"
"log"
"net/http"
"path/filepath"
"strconv"
@ -12,6 +12,8 @@ import (
"sync"
)
var Logger = logger.Logger.WithPrefix("Violet Error Pages")
// ErrorPages stores the custom error pages and is called by the servers to
// output meaningful pages for HTTP error codes
type ErrorPages struct {
@ -78,7 +80,7 @@ func (e *ErrorPages) threadCompile() {
if e.dir != nil {
err := e.internalCompile(errorPageMap)
if err != nil {
log.Printf("[ErrorPages] Compile failed: %s\n", err)
Logger.Info("Compile failed", "err", err)
return
}
}
@ -96,7 +98,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
return fmt.Errorf("failed to read error pages dir: %w", err)
}
log.Printf("[ErrorPages] Compiling lookup table for %d error pages\n", len(files))
Logger.Info("Compiling lookup table", "page count", len(files))
// find and load error pages
for _, i := range files {
@ -111,20 +113,20 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
// if the extension is not 'html' then ignore the file
if ext != ".html" {
log.Printf("[ErrorPages] WARNING: ignoring non '.html' file in error pages directory: '%s'\n", name)
Logger.Warn("Ignoring non '.html' file in error pages directory", "name", name)
continue
}
// if the name can't be
nameInt, err := strconv.Atoi(strings.TrimSuffix(name, ".html"))
if err != nil {
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name)
Logger.Warn("Ignoring invalid error page in error pages directory", "name", name)
continue
}
// check if code is in range 100-599
if nameInt < 100 || nameInt >= 600 {
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name)
Logger.Warn("Ignoring invalid error page in error pages directory must be 100-599", "name", name)
continue
}

View File

@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS favicons (
id INTEGER PRIMARY KEY AUTOINCREMENT,
host VARCHAR,
svg VARCHAR,
png VARCHAR,
ico VARCHAR
);

View File

@ -1,5 +1,7 @@
package favicons
import "database/sql"
// FaviconImage stores the url, hash and raw bytes of an image
type FaviconImage struct {
Url string
@ -9,9 +11,9 @@ type FaviconImage struct {
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
// the URL is an empty string.
func CreateFaviconImage(url string) *FaviconImage {
if url == "" {
func CreateFaviconImage(url sql.NullString) *FaviconImage {
if !url.Valid {
return nil
}
return &FaviconImage{Url: url}
return &FaviconImage{Url: url.String}
}

View File

@ -6,7 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/MrMelon54/png2ico"
"github.com/mrmelon54/png2ico"
"image/png"
"io"
"net/http"
@ -74,7 +74,7 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
// download SVG
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
return fmt.Errorf("favicons: failed to fetch SVG icon: %w", err)
}
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
}
@ -84,14 +84,14 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
// download PNG
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
return fmt.Errorf("favicons: failed to fetch PNG icon: %w", err)
}
} else if l.Svg != nil {
// generate PNG from SVG
l.Png = &FaviconImage{}
l.Png.Raw, err = convert(l.Svg.Raw)
if err != nil {
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
return fmt.Errorf("favicons: failed to generate PNG icon: %w", err)
}
}
@ -100,19 +100,19 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
// download ICO
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
if err != nil {
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
return fmt.Errorf("favicons: failed to fetch ICO icon: %w", err)
}
} else if l.Png != nil {
// generate ICO from PNG
l.Ico = &FaviconImage{}
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
if err != nil {
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
return fmt.Errorf("favicons: failed to decode PNG icon: %w", err)
}
b := decode.Bounds()
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
if err != nil {
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
return fmt.Errorf("favicons: failed to generate ICO icon: %w", err)
}
}
@ -139,16 +139,16 @@ func (l *FaviconList) genSha256() {
var getFaviconViaRequest = func(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
return nil, fmt.Errorf("favicons: Failed to send request '%s': %w", url, err)
}
req.Header.Set("X-Violet-Raw-Favicon", "1")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
return nil, fmt.Errorf("favicons: failed to do request '%s': %w", url, err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
return nil, fmt.Errorf("favicons: failed to read response '%s': %w", url, err)
}
return rawBody, nil
}

View File

@ -1,24 +1,24 @@
package favicons
import (
"database/sql"
"context"
_ "embed"
"errors"
"fmt"
"github.com/MrMelon54/rescheduler"
"github.com/1f349/violet/database"
"github.com/1f349/violet/logger"
"github.com/mrmelon54/rescheduler"
"golang.org/x/sync/errgroup"
"log"
"sync"
)
var ErrFaviconNotFound = errors.New("favicon not found")
var Logger = logger.Logger.WithPrefix("Violet Favicons")
//go:embed create-table-favicons.sql
var createTableFavicons string
var ErrFaviconNotFound = errors.New("favicon not found")
// Favicons is a dynamic favicon generator which supports overwriting favicons
type Favicons struct {
db *sql.DB
db *database.Queries
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
@ -26,7 +26,7 @@ type Favicons struct {
}
// New creates a new dynamic favicon generator
func New(db *sql.DB, inkscapeCmd string) *Favicons {
func New(db *database.Queries, inkscapeCmd string) *Favicons {
f := &Favicons{
db: db,
cmd: inkscapeCmd,
@ -35,13 +35,6 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons {
}
f.r = rescheduler.NewRescheduler(f.threadCompile)
// init favicons table
_, err := f.db.Exec(createTableFavicons)
if err != nil {
log.Printf("[WARN] Failed to generate 'favicons' table\n")
return nil
}
// run compile to get the initial data
f.Compile()
return f
@ -75,7 +68,7 @@ func (f *Favicons) threadCompile() {
err := f.internalCompile(favicons)
if err != nil {
// log compile errors
log.Printf("[Favicons] Compile failed: %s\n", err)
Logger.Info("Compile failed", "err", err)
return
}
@ -89,29 +82,23 @@ func (f *Favicons) threadCompile() {
// favicons.
func (f *Favicons) internalCompile(m map[string]*FaviconList) error {
// query all rows in database
query, err := f.db.Query(`select host, svg, png, ico from favicons`)
rows, err := f.db.GetFavicons(context.Background())
if err != nil {
return fmt.Errorf("failed to prepare query: %w", err)
return fmt.Errorf("failed to prepare rows: %w", err)
}
// loop over rows and scan in data using error group to catch errors
var g errgroup.Group
for query.Next() {
var host, rawSvg, rawPng, rawIco string
err := query.Scan(&host, &rawSvg, &rawPng, &rawIco)
if err != nil {
return fmt.Errorf("failed to scan row: %w", err)
}
for _, row := range rows {
// create favicon list for this row
l := &FaviconList{
Ico: CreateFaviconImage(rawIco),
Png: CreateFaviconImage(rawPng),
Svg: CreateFaviconImage(rawSvg),
Ico: CreateFaviconImage(row.Ico),
Png: CreateFaviconImage(row.Png),
Svg: CreateFaviconImage(row.Svg),
}
// save the favicon list to the map
m[host] = l
m[row.Host] = l
// run the pre-process in a separate goroutine
g.Go(func() error {

View File

@ -2,8 +2,11 @@ package favicons
import (
"bytes"
"context"
"database/sql"
_ "embed"
"github.com/1f349/violet"
"github.com/1f349/violet/database"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"image/png"
@ -22,11 +25,17 @@ var (
func TestFaviconsNew(t *testing.T) {
getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil }
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := violet.InitDB("file:TestFaviconsNew?mode=memory&cache=shared")
assert.NoError(t, err)
favicons := New(db, "inkscape")
_, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "")
err = db.UpdateFaviconCache(context.Background(), database.UpdateFaviconCacheParams{
Host: "example.com",
Svg: sql.NullString{
String: "https://example.com/assets/logo.svg",
Valid: true,
},
})
assert.NoError(t, err)
favicons.cLock.Lock()
assert.NoError(t, favicons.internalCompile(favicons.faviconMap))

68
go.mod
View File

@ -1,40 +1,62 @@
module github.com/1f349/violet
go 1.20
go 1.22
require (
github.com/1f349/mjwt v0.2.5
github.com/AlecAivazis/survey/v2 v2.3.7
github.com/MrMelon54/certgen v0.0.1
github.com/MrMelon54/exit-reload v0.0.1
github.com/MrMelon54/mjwt v0.1.1
github.com/MrMelon54/png2ico v1.0.1
github.com/MrMelon54/rescheduler v0.0.1
github.com/MrMelon54/trie v0.0.2
github.com/charmbracelet/log v0.4.0
github.com/cloudflare/tableflip v1.2.3
github.com/golang-migrate/migrate/v4 v4.17.1
github.com/google/subcommands v1.2.0
github.com/google/uuid v1.3.1
github.com/gorilla/websocket v1.5.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.1
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16
github.com/rs/cors v1.9.0
github.com/sethvargo/go-limiter v0.7.2
github.com/stretchr/testify v1.8.4
golang.org/x/net v0.9.0
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
github.com/mattn/go-sqlite3 v1.14.22
github.com/mrmelon54/certgen v0.0.2
github.com/mrmelon54/png2ico v1.0.2
github.com/mrmelon54/rescheduler v0.0.3
github.com/mrmelon54/trie v0.0.3
github.com/prometheus/client_golang v1.19.1
github.com/rs/cors v1.11.0
github.com/sethvargo/go-limiter v1.0.0
github.com/stretchr/testify v1.9.0
golang.org/x/net v0.25.0
golang.org/x/sync v0.7.0
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/becheran/wildmatch-go v1.0.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/lipgloss v0.10.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/mattn/go-colorable v0.1.2 // indirect
github.com/mattn/go-isatty v0.0.8 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/term v0.7.0 // indirect
golang.org/x/text v0.9.0 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.14.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/term v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

148
go.sum
View File

@ -1,100 +1,162 @@
github.com/1f349/mjwt v0.2.5 h1:IxjLaali22ayTzZ628lH7j0JDdYJoj6+CJ/VktCqtXQ=
github.com/1f349/mjwt v0.2.5/go.mod h1:KEs6jd9JjWrQW+8feP2pGAU7pdA3aYTqjkT/YQr73PU=
github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ=
github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo=
github.com/MrMelon54/certgen v0.0.1 h1:ycWdZ2RlxQ5qSuejeBVv4aXjGo5hdqqL4j4EjrXnFMk=
github.com/MrMelon54/certgen v0.0.1/go.mod h1:GHflVlSbtFLJZLpN1oWyUvDBRrR8qCWiwZLXCCnS2Gc=
github.com/MrMelon54/exit-reload v0.0.1 h1:sxHa59tNEQMcikwuX2+93lw6Vi1+R7oCRF8a0C3alXc=
github.com/MrMelon54/exit-reload v0.0.1/go.mod h1:PLiSfmUzwdpTTQP3BBfUPhkqPwaIZjx0DuXBnM76Bug=
github.com/MrMelon54/mjwt v0.1.1 h1:m+aTpxbhQCrOPKHN170DQMFR5r938LkviU38unob5Jw=
github.com/MrMelon54/mjwt v0.1.1/go.mod h1:oYrDBWK09Hju98xb+bRQ0wy+RuAzacxYvKYOZchR2Tk=
github.com/MrMelon54/png2ico v1.0.1 h1:zJoSSl4OkvSIMWGyGPvb8fWNa0KrUvMIjgNGLNLJhVQ=
github.com/MrMelon54/png2ico v1.0.1/go.mod h1:NOv3tO4497mInG+3tcFkIohmxCywUwMLU8WNxJZLVmU=
github.com/MrMelon54/rescheduler v0.0.1 h1:gzNvL8X81M00uYN0i9clFVrXCkG1UuLNYxDcvjKyBqo=
github.com/MrMelon54/rescheduler v0.0.1/go.mod h1:OQDFtZHdS4/qA/r7rtJUQA22/hbpnZ9MGQCXOPjhC6w=
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s=
github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/charmbracelet/log v0.4.0 h1:G9bQAcx8rWA2T3pWvx7YtPTPwgqpk7D68BX21IRW8ZM=
github.com/charmbracelet/log v0.4.0/go.mod h1:63bXt/djrizTec0l11H20t8FDSvA4CRZJ1KH22MdptM=
github.com/cloudflare/tableflip v1.2.3 h1:8I+B99QnnEWPHOY3fWipwVKxS70LGgUsslG7CSfmHMw=
github.com/cloudflare/tableflip v1.2.3/go.mod h1:P4gRehmV6Z2bY5ao5ml9Pd8u6kuEnlB37pUFMmv7j2E=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI=
github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4=
github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/mrmelon54/certgen v0.0.2 h1:4CMDkA/gGZu+E4iikU+5qdOWK7qOQrk58KtUfnmyYmY=
github.com/mrmelon54/certgen v0.0.2/go.mod h1:vwrWSXQmxZYqEyh+cf05IvDIFV2aYuxL4+O6ABIlN8M=
github.com/mrmelon54/png2ico v1.0.2 h1:KyJd3ATmDjxAJS28MTSf44GxzYnlZ+7KT8SXzGb3sN8=
github.com/mrmelon54/png2ico v1.0.2/go.mod h1:vp8Be9y5cz102ANon+BnsIzTUdet3VQRvOuWJTH9h0M=
github.com/mrmelon54/rescheduler v0.0.3 h1:TrkJL6S7PKvXuo1mvdgRgsILA/pk5L1lrXhV/q7IEzQ=
github.com/mrmelon54/rescheduler v0.0.3/go.mod h1:q415n6W1xcePPP5Rix6FOiADgcN66BYMyNOsFnNyoWQ=
github.com/mrmelon54/trie v0.0.3 h1:wZmws84FiGNBZJ00garLyQ2EQhtx0SipVoV7fK8+kZE=
github.com/mrmelon54/trie v0.0.3/go.mod h1:d3hl0YUBSWR3XN4S9BDLkGVzLT4VgwP2mZkBJM6uFpw=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE=
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.14.0 h1:Lw4VdGGoKEZilJsayHf0B+9YgLGREba2C6xr+Fdfq6s=
github.com/prometheus/procfs v0.14.0/go.mod h1:XL+Iwz8k8ZabyZfMFHPiilCniixqQarAy5Mu67pHlNQ=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po=
github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sethvargo/go-limiter v1.0.0 h1:JqW13eWEMn0VFv86OKn8wiYJY/m250WoXdrjRV0kLe4=
github.com/sethvargo/go-limiter v1.0.0/go.mod h1:01b6tW25Ap+MeLYBuD4aHunMrJoNO5PVUFdS9rac3II=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.7.0 h1:BEvjmm5fURWqcfbSKTdpkDXYBrUS1c0m8agp14W48vQ=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

38
initdb.go Normal file
View File

@ -0,0 +1,38 @@
package violet
import (
"database/sql"
"embed"
"errors"
"github.com/1f349/violet/database"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
)
//go:embed database/migrations/*.sql
var migrations embed.FS
func InitDB(p string) (*database.Queries, error) {
migDrv, err := iofs.New(migrations, "database/migrations")
if err != nil {
return nil, err
}
dbOpen, err := sql.Open("sqlite3", p)
if err != nil {
return nil, err
}
dbDrv, err := sqlite3.WithInstance(dbOpen, &sqlite3.Config{})
if err != nil {
return nil, err
}
mig, err := migrate.NewWithInstance("iofs", migDrv, "sqlite3", dbDrv)
if err != nil {
return nil, err
}
err = mig.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return nil, err
}
return database.New(dbOpen), nil
}

12
logger/logger.go Normal file
View File

@ -0,0 +1,12 @@
package logger
import (
"github.com/charmbracelet/log"
"os"
)
var Logger = log.NewWithOptions(os.Stderr, log.Options{
ReportCaller: true,
ReportTimestamp: true,
Prefix: "Violet",
})

View File

@ -2,15 +2,18 @@ package proxy
import (
"crypto/tls"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy/websocket"
"github.com/google/uuid"
"log"
"net"
"net/http"
"sync"
"time"
)
var loggerSecure = logger.Logger.WithPrefix("Violet Secure Transport")
var loggerInsecure = logger.Logger.WithPrefix("Violet Insecure Transport")
var loggerWebsocket = logger.Logger.WithPrefix("Violet Websocket Transport")
type HybridTransport struct {
baseDialer *net.Dialer
normalTransport http.RoundTripper
@ -71,24 +74,15 @@ func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websock
// SecureRoundTrip calls the secure transport
func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, error) {
u := uuid.New()
log.Println("[Transport] Start upgrade:", u)
defer log.Println("[Transport] Stop upgrade:", u)
return h.normalTransport.RoundTrip(req)
}
// InsecureRoundTrip calls the insecure transport
func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) {
u := uuid.New()
log.Println("[Transport insecure] Start upgrade:", u)
defer log.Println("[Transport insecure] Stop upgrade:", u)
return h.insecureTransport.RoundTrip(req)
}
// ConnectWebsocket calls the websocket upgrader and thus hijacks the connection
func (h *HybridTransport) ConnectWebsocket(rw http.ResponseWriter, req *http.Request) {
u := uuid.New()
log.Println("[Websocket] Start upgrade:", u)
h.ws.Upgrade(rw, req)
log.Println("[Websocket] Stop upgrade:", u)
}

View File

@ -1,13 +1,16 @@
package websocket
import (
"github.com/1f349/violet/logger"
"github.com/gorilla/websocket"
"log"
"net/http"
"slices"
"sync"
"time"
)
var Logger = logger.Logger.WithPrefix("Violet Websocket")
var upgrader = websocket.Upgrader{
HandshakeTimeout: time.Second * 5,
ReadBufferSize: 1024,
@ -34,7 +37,7 @@ func NewServer() *Server {
func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
req.URL.Scheme = "ws"
log.Printf("[Websocket] Upgrading request to '%s' from '%s'\n", req.URL.String(), req.Header.Get("Origin"))
Logger.Info("Upgrading request", "url", req.URL, "origin", req.Header.Get("Origin"))
c, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
@ -54,12 +57,12 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
s.conns[c.RemoteAddr().String()] = c
s.connLock.Unlock()
log.Printf("[Websocket] Dialing: '%s'\n", req.URL.String())
Logger.Info("Dialing", "url", req.URL)
// dial for internal connection
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), nil)
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), filterWebsocketHeaders(req.Header))
if err != nil {
log.Printf("[Websocket] Failed to dial '%s': %s\n", req.URL.String(), err)
Logger.Info("Failed to dial", "url", req.URL, "err", err)
s.Remove(c)
return
}
@ -73,7 +76,7 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
go s.wsRelay(d2, ic, c)
// wait for done signal and close both connections
log.Println("[Websocket] Completed websocket hijacking")
Logger.Info("Completed websocket hijacking")
// waiting until d1 or d2 close then automatically defer close both connections
select {
@ -82,6 +85,17 @@ func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
}
}
// filterWebsocketHeaders allows specific headers to forward to the underlying websocket connection
func filterWebsocketHeaders(headers http.Header) (out http.Header) {
out = make(http.Header)
for k, v := range headers {
if k == "Origin" {
out[k] = slices.Clone(v)
}
}
return
}
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
defer func() {
close(done)
@ -89,7 +103,7 @@ func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
for {
mt, message, err := a.ReadMessage()
if err != nil {
log.Println("Websocket read message error: ", err)
Logger.Info("Read message", "err", err)
return
}
if b.WriteMessage(mt, message) != nil {

View File

@ -1,18 +0,0 @@
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE,
destination TEXT,
flags INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
);
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE,
destination TEXT,
flags INTEGER DEFAULT 0,
code INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
);

View File

@ -1,34 +1,33 @@
package router
import (
"database/sql"
"context"
_ "embed"
"github.com/1f349/violet/database"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/target"
"github.com/MrMelon54/rescheduler"
"log"
"github.com/mrmelon54/rescheduler"
"net/http"
"strings"
"sync"
)
var Logger = logger.Logger.WithPrefix("Violet Manager")
// Manager is a database and mutex wrap around router allowing it to be
// dynamically regenerated after updating the database of routes.
type Manager struct {
db *sql.DB
db *database.Queries
s *sync.RWMutex
r *Router
p *proxy.HybridTransport
z *rescheduler.Rescheduler
}
var (
//go:embed create-tables.sql
createTables string
)
// NewManager create a new manager, initialises the routes and redirects tables
// in the database and runs a first time compile.
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
func NewManager(db *database.Queries, proxy *proxy.HybridTransport) *Manager {
m := &Manager{
db: db,
s: &sync.RWMutex{},
@ -36,13 +35,6 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
p: proxy,
}
m.z = rescheduler.NewRescheduler(m.threadCompile)
// init routes table
_, err := m.db.Exec(createTables)
if err != nil {
log.Printf("[WARN] Failed to generate tables\n")
return nil
}
return m
}
@ -64,7 +56,7 @@ func (m *Manager) threadCompile() {
// compile router and check errors
err := m.internalCompile(router)
if err != nil {
log.Printf("[Manager] Compile failed: %s\n", err)
Logger.Info("Compile failed", "err", err)
return
}
@ -77,123 +69,160 @@ func (m *Manager) threadCompile() {
// internalCompile is a hidden internal method for querying the database during
// the Compile() method.
func (m *Manager) internalCompile(router *Router) error {
log.Println("[Manager] Updating routes from database")
Logger.Info("Updating routes from database")
// sql or something?
rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`)
routeRows, err := m.db.GetActiveRoutes(context.Background())
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the options
for rows.Next() {
var (
src, dst string
flags target.Flags
)
err := rows.Scan(&src, &dst, &flags)
if err != nil {
return err
}
for _, row := range routeRows {
router.AddRoute(target.Route{
Src: src,
Dst: dst,
Flags: flags.NormaliseRouteFlags(),
Src: row.Source,
Dst: row.Destination,
Flags: row.Flags.NormaliseRouteFlags(),
})
}
// check for errors
if err := rows.Err(); err != nil {
return err
}
// sql or something?
rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`)
redirectsRows, err := m.db.GetActiveRedirects(context.Background())
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the options
for rows.Next() {
var (
src, dst string
flags target.Flags
code int
)
err := rows.Scan(&src, &dst, &flags, &code)
if err != nil {
return err
}
for _, row := range redirectsRows {
router.AddRedirect(target.Redirect{
Src: src,
Dst: dst,
Flags: flags.NormaliseRedirectFlags(),
Code: code,
Src: row.Source,
Dst: row.Destination,
Flags: row.Flags.NormaliseRedirectFlags(),
Code: row.Code,
})
}
// check for errors
return rows.Err()
return nil
}
func (m *Manager) GetAllRoutes() ([]target.RouteWithActive, error) {
func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) {
if len(hosts) < 1 {
return []target.RouteWithActive{}, nil
}
s := make([]target.RouteWithActive, 0)
query, err := m.db.Query(`SELECT source, destination, flags, active FROM routes`)
rows, err := m.db.GetAllRoutes(context.Background())
if err != nil {
return nil, err
}
for query.Next() {
var a target.RouteWithActive
if query.Scan(&a.Src, &a.Dst, &a.Flags, &a.Active) != nil {
return nil, err
for _, row := range rows {
a := target.RouteWithActive{
Route: target.Route{
Src: row.Source,
Dst: row.Destination,
Desc: row.Description,
Flags: row.Flags,
},
Active: row.Active,
}
for _, i := range hosts {
// if this is never true then the domain was mistakenly grabbed from the database
if a.OnDomain(i) {
s = append(s, a)
break
}
}
s = append(s, a)
}
return s, nil
}
func (m *Manager) InsertRoute(route target.Route) error {
_, err := m.db.Exec(`INSERT INTO routes (source, destination, flags) VALUES (?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, active = 1`, route.Src, route.Dst, route.Flags)
return err
func (m *Manager) InsertRoute(route target.RouteWithActive) error {
return m.db.AddRoute(context.Background(), database.AddRouteParams{
Source: route.Src,
Destination: route.Dst,
Description: route.Desc,
Flags: route.Flags,
Active: route.Active,
})
}
func (m *Manager) DeleteRoute(source string) error {
_, err := m.db.Exec(`UPDATE routes SET active = 0 WHERE source = ?`, source)
return err
return m.db.RemoveRoute(context.Background(), source)
}
func (m *Manager) GetAllRedirects() ([]target.RedirectWithActive, error) {
func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, error) {
if len(hosts) < 1 {
return []target.RedirectWithActive{}, nil
}
s := make([]target.RedirectWithActive, 0)
query, err := m.db.Query(`SELECT source, destination, flags, code, active FROM redirects`)
rows, err := m.db.GetAllRedirects(context.Background())
if err != nil {
return nil, err
}
for query.Next() {
var a target.RedirectWithActive
if query.Scan(&a.Src, &a.Dst, &a.Flags, &a.Code, &a.Active) != nil {
return nil, err
for _, row := range rows {
a := target.RedirectWithActive{
Redirect: target.Redirect{
Src: row.Source,
Dst: row.Destination,
Desc: row.Description,
Flags: row.Flags,
Code: row.Code,
},
Active: row.Active,
}
for _, i := range hosts {
// if this is never true then the domain was mistakenly grabbed from the database
if a.OnDomain(i) {
s = append(s, a)
break
}
}
s = append(s, a)
}
return s, nil
}
func (m *Manager) InsertRedirect(redirect target.Redirect) error {
_, err := m.db.Exec(`INSERT INTO redirects (source, destination, flags, code) VALUES (?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, code = excluded.code, active = 1`, redirect.Src, redirect.Dst, redirect.Flags, redirect.Code)
return err
func (m *Manager) InsertRedirect(redirect target.RedirectWithActive) error {
return m.db.AddRedirect(context.Background(), database.AddRedirectParams{
Source: redirect.Src,
Destination: redirect.Dst,
Description: redirect.Desc,
Flags: redirect.Flags,
Code: redirect.Code,
Active: redirect.Active,
})
}
func (m *Manager) DeleteRedirect(source string) error {
_, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source)
return err
return m.db.RemoveRedirect(context.Background(), source)
}
// GenerateHostSearch this should help improve performance
// TODO(Melon) discover how to implement this correctly
func GenerateHostSearch(hosts []string) (string, []string) {
var searchString strings.Builder
searchString.WriteString("WHERE ")
hostArgs := make([]string, len(hosts)*2)
for i := range hosts {
if i != 0 {
searchString.WriteString(" OR ")
}
// these like checks are not perfect but do reduce load on the database
searchString.WriteString("source LIKE '%' + ? + '/%'")
searchString.WriteString(" OR source LIKE '%' + ?")
// loads the hostname into even and odd args
hostArgs[i*2] = hosts[i]
hostArgs[i*2+1] = hosts[i]
}
return searchString.String(), hostArgs
}

View File

@ -1,7 +1,9 @@
package router
import (
"database/sql"
"context"
"github.com/1f349/violet"
"github.com/1f349/violet/database"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
@ -22,7 +24,7 @@ func (f *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
func TestNewManager(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := violet.InitDB("file:TestNewManager?mode=memory&cache=shared")
assert.NoError(t, err)
ft := &fakeTransport{}
@ -39,7 +41,13 @@ func TestNewManager(t *testing.T) {
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Nil(t, ft.req)
_, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr)
err = db.AddRoute(context.Background(), database.AddRouteParams{
Source: "*.example.com",
Destination: "127.0.0.1:8080",
Description: "",
Flags: target.FlagAbs | target.FlagForwardHost | target.FlagForwardAddr,
Active: true,
})
assert.NoError(t, err)
assert.NoError(t, m.internalCompile(m.r))
@ -50,3 +58,71 @@ func TestNewManager(t *testing.T) {
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.NotNil(t, ft.req)
}
func TestManager_GetAllRoutes(t *testing.T) {
db, err := violet.InitDB("file:TestManager_GetAllRoutes?mode=memory&cache=shared")
assert.NoError(t, err)
m := NewManager(db, nil)
a := []error{
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.com"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.com"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.com/hello"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.com/hello"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.org"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.org"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.org/hello"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.org/hello"}, Active: true}),
}
for _, i := range a {
if i != nil {
t.Fatal(i)
}
}
routes, err := m.GetAllRoutes([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, []target.RouteWithActive{
{Route: target.Route{Src: "example.com"}, Active: true},
{Route: target.Route{Src: "test.example.com"}, Active: true},
{Route: target.Route{Src: "example.com/hello"}, Active: true},
{Route: target.Route{Src: "test.example.com/hello"}, Active: true},
}, routes)
}
func TestManager_GetAllRedirects(t *testing.T) {
db, err := violet.InitDB("file:TestManager_GetAllRedirects?mode=memory&cache=shared")
assert.NoError(t, err)
m := NewManager(db, nil)
a := []error{
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.com"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.com"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.com/hello"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.com/hello"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.org"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.org"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.org/hello"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.org/hello"}, Active: true}),
}
for _, i := range a {
if i != nil {
t.Fatal(i)
}
}
redirects, err := m.GetAllRedirects([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, []target.RedirectWithActive{
{Redirect: target.Redirect{Src: "example.com"}, Active: true},
{Redirect: target.Redirect{Src: "test.example.com"}, Active: true},
{Redirect: target.Redirect{Src: "example.com/hello"}, Active: true},
{Redirect: target.Redirect{Src: "test.example.com/hello"}, Active: true},
}, redirects)
}
func TestGenerateHostSearch(t *testing.T) {
query, args := GenerateHostSearch([]string{"example.com", "example.org"})
assert.Equal(t, "WHERE source LIKE '%' + ? + '/%' OR source LIKE '%' + ? OR source LIKE '%' + ? + '/%' OR source LIKE '%' + ?", query)
assert.Equal(t, []string{"example.com", "example.com", "example.org", "example.org"}, args)
}

View File

@ -5,7 +5,7 @@ import (
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/target"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/trie"
"github.com/mrmelon54/trie"
"net/http"
"strings"
)

View File

@ -5,7 +5,7 @@ import (
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
"github.com/MrMelon54/trie"
"github.com/mrmelon54/trie"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"

View File

@ -2,11 +2,13 @@ package api
import (
"encoding/json"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/claims"
"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"net/http"
"time"
)
@ -15,12 +17,15 @@ import (
// endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Server {
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable, registry *prometheus.Registry) *http.Server {
r := httprouter.New()
r.GET("/", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
http.Error(rw, "Violet API Endpoint", http.StatusOK)
})
r.GET("/metrics", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
promhttp.HandlerFor(registry, promhttp.HandlerOpts{}).ServeHTTP(rw, req)
})
// Endpoint for compile action
r.POST("/compile", checkAuthWithPerm(conf.Signer, "violet:compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, b AuthClaims) {
@ -43,7 +48,6 @@ func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Se
// Create and run http server
return &http.Server{
Addr: conf.ApiListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
@ -86,11 +90,21 @@ func acmeChallengeManage(verify mjwt.Verifier, domains utils.DomainProvider, acm
})
}
// getDomainOwnershipClaims returns the domains marked as owned from PermStorage,
// they match `domain:owns=<fqdn>` where fqdn will be returned
func getDomainOwnershipClaims(perms *claims.PermStorage) []string {
a := perms.Search("domain:owns=*")
for i := range a {
a[i] = a[i][len("domain:owns="):]
}
return a
}
// validateDomainOwnershipClaims validates if the claims contain the
// `owns=<fqdn>` field with the matching top level domain
// `domain:owns=<fqdn>` field with the matching top level domain
func validateDomainOwnershipClaims(a string, perms *claims.PermStorage) bool {
if fqdn, ok := utils.GetTopFqdn(a); ok {
if perms.Has("owns=" + fqdn) {
if perms.Has("domain:owns=" + fqdn) {
return true
}
}

View File

@ -17,7 +17,7 @@ func TestNewApiServer_Compile(t *testing.T) {
Signer: fake.SnakeOilProv,
}
f := &fake.Compilable{}
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
srv := NewApiServer(apiConf, utils.MultiCompilable{f}, nil)
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
assert.NoError(t, err)
@ -43,7 +43,7 @@ func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{})
srv := NewApiServer(apiConf, utils.MultiCompilable{}, nil)
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
// Valid domain
@ -87,7 +87,7 @@ func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{})
srv := NewApiServer(apiConf, utils.MultiCompilable{}, nil)
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
// Valid domain

View File

@ -1,9 +1,9 @@
package api
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/julienschmidt/httprouter"
"net/http"
)

View File

@ -10,11 +10,11 @@ type sourceJson struct {
func (s sourceJson) GetSource() string { return s.Src }
type routeSource target.Route
type routeSource target.RouteWithActive
func (r routeSource) GetSource() string { return r.Src }
type redirectSource target.Redirect
type redirectSource target.RedirectWithActive
func (r redirectSource) GetSource() string { return r.Src }

View File

@ -2,12 +2,12 @@ package api
import (
"encoding/json"
"github.com/1f349/mjwt"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/router"
"github.com/1f349/violet/target"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
"github.com/julienschmidt/httprouter"
"log"
"net/http"
"strings"
)
@ -15,8 +15,11 @@ import (
func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router.Manager) {
// Endpoint for routes
r.GET("/route", checkAuthWithPerm(verify, "violet:route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
routes, err := manager.GetAllRoutes()
domains := getDomainOwnershipClaims(b.Claims.Perms)
routes, err := manager.GetAllRoutes(domains)
if err != nil {
logger.Logger.Infof("Failed to get routes from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to get routes from database")
return
}
@ -24,9 +27,9 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
_ = json.NewEncoder(rw).Encode(routes)
}))
r.POST("/route", parseJsonAndCheckOwnership[routeSource](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t routeSource) {
err := manager.InsertRoute(target.Route(t))
err := manager.InsertRoute(target.RouteWithActive(t))
if err != nil {
log.Printf("[Violet] Failed to insert route into database: %s\n", err)
logger.Logger.Infof("Failed to insert route into database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to insert route into database")
return
}
@ -35,7 +38,7 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
r.DELETE("/route", parseJsonAndCheckOwnership[sourceJson](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
err := manager.DeleteRoute(t.Src)
if err != nil {
log.Printf("[Violet] Failed to delete route from database: %s\n", err)
logger.Logger.Infof("Failed to delete route from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to delete route from database")
return
}
@ -44,8 +47,11 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
// Endpoint for redirects
r.GET("/redirect", checkAuthWithPerm(verify, "violet:redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
redirects, err := manager.GetAllRedirects()
domains := getDomainOwnershipClaims(b.Claims.Perms)
redirects, err := manager.GetAllRedirects(domains)
if err != nil {
logger.Logger.Infof("Failed to get redirects from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to get redirects from database")
return
}
@ -53,9 +59,9 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
_ = json.NewEncoder(rw).Encode(redirects)
}))
r.POST("/redirect", parseJsonAndCheckOwnership[redirectSource](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t redirectSource) {
err := manager.InsertRedirect(target.Redirect(t))
err := manager.InsertRedirect(target.RedirectWithActive(t))
if err != nil {
log.Printf("[Violet] Failed to insert redirect into database: %s\n", err)
logger.Logger.Infof("Failed to insert redirect into database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to insert redirect into database")
return
}
@ -64,7 +70,7 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
r.DELETE("/redirect", parseJsonAndCheckOwnership[sourceJson](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
err := manager.DeleteRedirect(t.Src)
if err != nil {
log.Printf("[Violet] Failed to delete redirect from database: %s\n", err)
logger.Logger.Infof("Failed to delete redirect from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to delete redirect from database")
return
}

View File

@ -1,26 +1,23 @@
package conf
import (
"database/sql"
"github.com/1f349/mjwt"
"github.com/1f349/violet/database"
errorPages "github.com/1f349/violet/error-pages"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/router"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
)
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
type Conf struct {
ApiListen string // api server listen address
HttpListen string // http server listen address
HttpsListen string // https server listen address
RateLimit uint64 // rate limit per minute
DB *sql.DB
Domains utils.DomainProvider
Acme utils.AcmeChallengeProvider
Certs utils.CertProvider
Favicons *favicons.Favicons
Signer mjwt.Verifier
ErrorPages *errorPages.ErrorPages
Router *router.Manager
RateLimit uint64 // rate limit per minute
DB *database.Queries
Domains utils.DomainProvider
Acme utils.AcmeChallengeProvider
Certs utils.CertProvider
Favicons *favicons.Favicons
Signer mjwt.Verifier
ErrorPages *errorPages.ErrorPages
Router *router.Manager
}

View File

@ -3,8 +3,10 @@ package servers
import (
"fmt"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/servers/metrics"
"github.com/1f349/violet/utils"
"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus"
"net/http"
"net/url"
"time"
@ -15,13 +17,9 @@ import (
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Let's Encrypt HTTP verification.
func NewHttpServer(conf *conf.Conf) *http.Server {
func NewHttpServer(httpsPort uint16, conf *conf.Conf, registry *prometheus.Registry) *http.Server {
r := httprouter.New()
var secureExtend string
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
if !ok {
httpsPort = 443
}
if httpsPort != 443 {
secureExtend = fmt.Sprintf(":%d", httpsPort)
}
@ -61,10 +59,16 @@ func NewHttpServer(conf *conf.Conf) *http.Server {
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
})
metricsMiddleware := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r.ServeHTTP(rw, req)
})
if registry != nil {
metricsMiddleware = metrics.New(registry, nil).WrapHandler("violet-http-insecure", r)
}
// Create and run http server
return &http.Server{
Addr: conf.HttpListen,
Handler: r,
Handler: metricsMiddleware,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,

View File

@ -18,7 +18,7 @@ func TestNewHttpServer_AcmeChallenge(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewHttpServer(httpConf)
srv := NewHttpServer(443, httpConf, nil)
httpConf.Acme.Put("example.com", "456", "456def")
req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil)

View File

@ -4,11 +4,13 @@ import (
"crypto/tls"
"fmt"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/servers/metrics"
"github.com/1f349/violet/utils"
"github.com/prometheus/client_golang/prometheus"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
"log"
"net/http"
"path"
"runtime"
@ -17,32 +19,57 @@ import (
// NewHttpsServer creates and runs a http server containing the public https
// endpoints for the reverse proxy.
func NewHttpsServer(conf *conf.Conf) *http.Server {
func NewHttpsServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server {
r := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
log.Printf("[Debug] Request: %s - '%s' - '%s' - '%s' - len: %d - thread: %d\n", req.Method, req.URL.String(), req.RemoteAddr, req.Host, req.ContentLength, runtime.NumGoroutine())
logger.Logger.Debug("Request", "method", req.Method, "url", req.URL, "remote", req.RemoteAddr, "host", req.Host, "length", req.ContentLength, "goroutine", runtime.NumGoroutine())
conf.Router.ServeHTTP(rw, req)
})
favMiddleware := setupFaviconMiddleware(conf.Favicons, r)
rateLimiter := setupRateLimiter(conf.RateLimit, favMiddleware)
metricsMeta := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r.ServeHTTP(rw, req)
})
if registry != nil {
metricsMiddleware := metrics.New(registry, nil).WrapHandler("violet-https", favMiddleware)
metricsMeta = func(rw http.ResponseWriter, req *http.Request) {
metricsMiddleware.ServeHTTP(rw, metrics.AddHostCtx(req))
}
}
rateLimiter := setupRateLimiter(conf.RateLimit, metricsMeta)
hsts := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
rateLimiter.ServeHTTP(rw, req)
})
return &http.Server{
Addr: conf.HttpsListen,
Handler: rateLimiter,
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains
if !conf.Domains.IsValid(info.ServerName) {
return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName)
}
Handler: hsts,
TLSConfig: &tls.Config{
// Suggested by https://ssl-config.mozilla.org/#server=go&version=1.21.5&config=intermediate
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains
if !conf.Domains.IsValid(info.ServerName) {
return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName)
}
// find a certificate
cert := conf.Certs.GetCertForDomain(info.ServerName)
if cert == nil {
return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName)
}
// find a certificate
cert := conf.Certs.GetCertForDomain(info.ServerName)
if cert == nil {
return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName)
}
// time to return
return cert, nil
}},
// time to return
return cert, nil
},
},
ReadTimeout: 150 * time.Second,
ReadHeaderTimeout: 150 * time.Second,
WriteTimeout: 150 * time.Second,
@ -60,13 +87,13 @@ func setupRateLimiter(rateLimit uint64, next http.Handler) http.Handler {
Interval: time.Minute,
})
if err != nil {
log.Fatalln(err)
logger.Logger.Fatal("Failed to initialize memory store", "err", err)
}
// create a middleware using ips as the key for rate limits
middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc())
if err != nil {
log.Fatalln(err)
logger.Logger.Fatal("Failed to initialize httplimit middleware", "err", err)
}
return middleware.Handle(next)
}

View File

@ -1,7 +1,7 @@
package servers
import (
"database/sql"
"github.com/1f349/violet"
"github.com/1f349/violet/certs"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
@ -25,7 +25,7 @@ func (f *fakeTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
}
func TestNewHttpsServer_RateLimit(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := violet.InitDB("file:TestNewHttpsServer_RateLimit?mode=memory&cache=shared")
assert.NoError(t, err)
ft := &fakeTransport{}
@ -36,7 +36,7 @@ func TestNewHttpsServer_RateLimit(t *testing.T) {
Signer: fake.SnakeOilProv,
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{})),
}
srv := NewHttpsServer(httpsConf)
srv := NewHttpsServer(httpsConf, nil)
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
req.RemoteAddr = "127.0.0.1:1447"

View File

@ -0,0 +1,118 @@
package metrics
import (
"context"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"net/http"
)
// Copyright 2022 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package metrics is adapted from
// https://github.com/bwplotka/correlator/tree/main/examples/observability/ping/pkg/httpinstrumentation
// https://github.com/prometheus/client_golang/blob/main/examples/middleware/httpmiddleware/httpmiddleware.go
type Middleware interface {
// WrapHandler wraps the given HTTP handler for instrumentation.
WrapHandler(handlerName string, handler http.Handler) http.HandlerFunc
}
type middleware struct {
buckets []float64
registry prometheus.Registerer
}
// WrapHandler wraps the given HTTP handler for instrumentation:
// It registers four metric collectors (if not already done) and reports HTTP
// metrics to the (newly or already) registered collectors.
// Each has a constant label named "handler" with the provided handlerName as
// value.
func (m *middleware) WrapHandler(handlerName string, handler http.Handler) http.HandlerFunc {
reg := prometheus.WrapRegistererWith(prometheus.Labels{"handler": handlerName}, m.registry)
requestsTotal := promauto.With(reg).NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Tracks the number of HTTP requests.",
}, []string{"method", "code", "host"},
)
requestDuration := promauto.With(reg).NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "Tracks the latencies for HTTP requests.",
Buckets: m.buckets,
},
[]string{"method", "code", "host"},
)
requestSize := promauto.With(reg).NewSummaryVec(
prometheus.SummaryOpts{
Name: "http_request_size_bytes",
Help: "Tracks the size of HTTP requests.",
},
[]string{"method", "code", "host"},
)
responseSize := promauto.With(reg).NewSummaryVec(
prometheus.SummaryOpts{
Name: "http_response_size_bytes",
Help: "Tracks the size of HTTP responses.",
},
[]string{"method", "code", "host"},
)
hostCtxGetter := promhttp.WithLabelFromCtx("host", func(ctx context.Context) string {
s, _ := ctx.Value(hostCtxKey(0)).(string)
return s
})
// Wraps the provided http.Handler to observe the request result with the provided metrics.
base := promhttp.InstrumentHandlerCounter(
requestsTotal,
promhttp.InstrumentHandlerDuration(
requestDuration,
promhttp.InstrumentHandlerRequestSize(
requestSize,
promhttp.InstrumentHandlerResponseSize(
responseSize,
handler,
hostCtxGetter,
),
hostCtxGetter,
),
hostCtxGetter,
),
hostCtxGetter,
)
return base.ServeHTTP
}
// New returns a Middleware interface.
func New(registry prometheus.Registerer, buckets []float64) Middleware {
if buckets == nil {
buckets = prometheus.ExponentialBuckets(0.1, 1.5, 5)
}
return &middleware{
buckets: buckets,
registry: registry,
}
}
type hostCtxKey uint8
func AddHostCtx(req *http.Request) *http.Request {
return req.WithContext(context.WithValue(req.Context(), hostCtxKey(0), req.Host))
}

15
sqlc.yaml Normal file
View File

@ -0,0 +1,15 @@
version: "2"
sql:
- engine: sqlite
queries: database/queries
schema: database/migrations
gen:
go:
package: "database"
out: "database"
emit_json_tags: true
overrides:
- column: "routes.flags"
go_type: "github.com/1f349/violet/target.Flags"
- column: "redirects.flags"
go_type: "github.com/1f349/violet/target.Flags"

View File

@ -14,8 +14,9 @@ import (
type Redirect struct {
Src string `json:"src"` // request source
Dst string `json:"dst"` // redirect destination
Desc string `json:"desc"` // description for admin panel use
Flags Flags `json:"flags"` // extra flags
Code int `json:"code"` // status code used to redirect
Code int64 `json:"code"` // status code used to redirect
}
type RedirectWithActive struct {
@ -23,6 +24,17 @@ type RedirectWithActive struct {
Active bool `json:"active"`
}
func (r Redirect) OnDomain(domain string) bool {
// if there is no / then the first part is still the domain
domainPart, _, _ := strings.Cut(r.Src, "/")
if domainPart == domain {
return true
}
// domainPart could start with a subdomain
return strings.HasSuffix(domainPart, "."+domain)
}
func (r Redirect) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}
@ -66,7 +78,7 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// use fast redirect for speed
utils.FastRedirect(rw, req, u.String(), code)
utils.FastRedirect(rw, req, u.String(), int(code))
}
// String outputs a debug string for the redirect.

View File

@ -7,6 +7,22 @@ import (
"testing"
)
func TestRedirect_OnDomain(t *testing.T) {
assert.True(t, Route{Src: "example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "example.com/hello"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com/hello"}.OnDomain("example.com"))
assert.False(t, Route{Src: "example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "example.com/hello"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com/hello"}.OnDomain("example.org"))
}
func TestRedirect_HasFlag(t *testing.T) {
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
}
func TestRedirect_ServeHTTP(t *testing.T) {
a := []struct {
Redirect
@ -19,7 +35,7 @@ func TestRedirect_ServeHTTP(t *testing.T) {
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
i.ServeHTTP(res, req)
assert.Equal(t, i.Code, res.Code)
assert.Equal(t, i.Code, int64(res.Code))
assert.Equal(t, i.target, res.Header().Get("Location"))
}
}

View File

@ -2,14 +2,13 @@ package target
import (
"fmt"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/utils"
"github.com/google/uuid"
websocket2 "github.com/gorilla/websocket"
"github.com/rs/cors"
"golang.org/x/net/http/httpguts"
"io"
"log"
"net"
"net/http"
"net/textproto"
@ -18,10 +17,13 @@ import (
"strings"
)
var Logger = logger.Logger.WithPrefix("Violet Serve Route")
// serveApiCors outputs the cors headers to make APIs work.
var serveApiCors = cors.New(cors.Options{
AllowedOrigins: []string{"*"}, // allow all origins for api requests
AllowedHeaders: []string{"Content-Type", "Authorization"},
// allow all origins for api requests
AllowOriginFunc: func(origin string) bool { return true },
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowedMethods: []string{
http.MethodGet,
http.MethodHead,
@ -39,6 +41,7 @@ var serveApiCors = cors.New(cors.Options{
type Route struct {
Src string `json:"src"` // request source
Dst string `json:"dst"` // proxy destination
Desc string `json:"desc"` // description for admin panel use
Flags Flags `json:"flags"` // extra flags
Headers http.Header `json:"-"` // extra headers
Proxy *proxy.HybridTransport `json:"-"` // reverse proxy handler
@ -49,6 +52,17 @@ type RouteWithActive struct {
Active bool `json:"active"`
}
func (r Route) OnDomain(domain string) bool {
// if there is no / then the first part is still the domain
domainPart, _, _ := strings.Cut(r.Src, "/")
if domainPart == domain {
return true
}
// domainPart could start with a subdomain
return strings.HasSuffix(domainPart, "."+domain)
}
func (r Route) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}
@ -115,8 +129,7 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// create the internal request
req2, err := http.NewRequest(req.Method, u.String(), req.Body)
if err != nil {
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
utils.RespondVioletError(rw, http.StatusBadGateway, "error generating new request")
utils.RespondVioletError(rw, http.StatusBadGateway, "Invalid request for proxy")
return
}
@ -166,8 +179,8 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
resp, err = r.Proxy.SecureRoundTrip(req2)
}
if err != nil {
log.Printf("[ServeRoute::ServeHTTP()] Error receiving internal round trip response: %s\n", err)
utils.RespondVioletError(rw, http.StatusBadGateway, "error receiving internal round trip response")
Logger.Warn("Error receiving internal round trip response", "route src", r.Src, "url", req2.URL.String(), "err", err)
utils.RespondVioletError(rw, http.StatusBadGateway, "Error receiving internal round trip response")
return
}
@ -177,9 +190,8 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if resp.StatusCode == http.StatusLoopDetected {
u := uuid.New()
log.Printf("[ServeRoute::ServeHTTP()] Loop Detected: %s %s '%s' -> '%s'\n", u, req.Method, req.URL.String(), req2.URL.String())
utils.RespondVioletError(rw, http.StatusLoopDetected, "error loop detected: "+u.String())
Logger.Warn("Loop Detected", "method", req.Method, "url", req.URL, "url2", req2.URL.String())
utils.RespondVioletError(rw, http.StatusLoopDetected, "Error loop detected")
return
}
@ -210,7 +222,7 @@ func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req, req2 *http.
reqUpType := upgradeType(req2.Header)
if !asciiIsPrint(reqUpType) {
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType))
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("Invalid protocol %s", reqUpType))
return true
}
removeHopByHopHeaders(req2.Header)

View File

@ -27,6 +27,17 @@ func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK}, nil
}
func TestRoute_OnDomain(t *testing.T) {
assert.True(t, Route{Src: "example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "example.com/hello"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com/hello"}.OnDomain("example.com"))
assert.False(t, Route{Src: "example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "example.com/hello"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com/hello"}.OnDomain("example.org"))
}
func TestRoute_HasFlag(t *testing.T) {
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
@ -79,7 +90,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
assert.Equal(t, http.MethodOptions, pt.req.Method)
assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String())
assert.Equal(t, "Origin", res.Header().Get("Vary"))
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "https://test.example.com", res.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Origin", res.Header().Get("Vary"))
}

View File

@ -1,6 +1,7 @@
package utils
import (
"golang.org/x/net/publicsuffix"
"strconv"
"strings"
)
@ -65,23 +66,8 @@ func GetParentDomain(domain string) (string, bool) {
//
// hello.world.example.com => example.com
func GetTopFqdn(domain string) (string, bool) {
var countDot int
n := strings.LastIndexFunc(domain, func(r rune) bool {
// return true if this is the second '.'
// otherwise counts one and continues
if r == '.' {
if countDot == 1 {
return true
}
countDot++
}
return false
})
// if a valid index isn't found then return false
if n == -1 {
return "", false
}
return domain[n+1:], true
out, err := publicsuffix.EffectiveTLDPlusOne(domain)
return out, err == nil
}
// SplitHostPath extracts the host/path from the input

View File

@ -52,7 +52,11 @@ func TestGetBaseDomain(t *testing.T) {
}
func TestGetTopFqdn(t *testing.T) {
domain, ok := GetTopFqdn("www.example.com")
domain, ok := GetTopFqdn("example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetTopFqdn("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)

View File

@ -3,9 +3,9 @@ package fake
import (
"crypto/rand"
"crypto/rsa"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"time"
)

View File

@ -2,33 +2,34 @@ package utils
import (
"errors"
"log"
"github.com/charmbracelet/log"
"net"
"net/http"
"strings"
)
// logHttpServerError is the internal function powering the logging in
// RunBackgroundHttp and RunBackgroundHttps.
func logHttpServerError(prefix string, err error) {
func logHttpServerError(logger *log.Logger, err error) {
if err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Printf("[%s] The http server shutdown successfully\n", prefix)
logger.Info("The http server shutdown successfully")
} else {
log.Printf("[%s] Error trying to host the http server: %s\n", prefix, err.Error())
logger.Info("Error trying to host the http server", "err", err.Error())
}
}
}
// RunBackgroundHttp runs a http server and logs when the server closes or
// errors.
func RunBackgroundHttp(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServe())
func RunBackgroundHttp(logger *log.Logger, s *http.Server, ln net.Listener) {
logHttpServerError(logger, s.Serve(ln))
}
// RunBackgroundHttps runs a http server with TLS encryption and logs when the
// server closes or errors.
func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
func RunBackgroundHttps(logger *log.Logger, s *http.Server, ln net.Listener) {
logHttpServerError(logger, s.ServeTLS(ln, "", ""))
}
// GetBearer returns the bearer from the Authorization header or an empty string