Compare commits

..

No commits in common. "main" and "v0.0.4" have entirely different histories.
main ... v0.0.4

70 changed files with 693 additions and 2028 deletions

View File

@ -4,7 +4,7 @@ jobs:
test:
strategy:
matrix:
go-version: [1.22.x]
go-version: [1.20.x]
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
*.sqlite
*.local
.idea/
.data

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

5
.idea/codeStyles/codeStyleConfig.xml generated Normal file
View File

@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

12
.idea/dataSources.xml generated Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="identifier.sqlite" uuid="5b42d21a-92a8-43d0-8651-c1555b91060c">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:identifier.sqlite</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

7
.idea/discord.xml generated Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DiscordProjectSettings">
<option name="show" value="PROJECT_FILES" />
<option name="description" value="" />
</component>
</project>

6
.idea/misc.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
</project>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/violet.iml" filepath="$PROJECT_DIR$/.idea/violet.iml" />
</modules>
</component>
</project>

6
.idea/sqldialects.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="PROJECT" dialect="SQLite" />
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

9
.idea/violet.iml generated Normal file
View File

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -4,21 +4,18 @@ import (
"crypto/tls"
"crypto/x509/pkix"
"fmt"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/utils"
"github.com/mrmelon54/certgen"
"github.com/mrmelon54/rescheduler"
"github.com/MrMelon54/certgen"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/utils"
"io/fs"
"log"
"math/big"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
var Logger = logger.Logger.WithPrefix("Violet Certs")
// Certs is the certificate loader and management system.
type Certs struct {
cDir fs.FS
@ -29,8 +26,6 @@ type Certs struct {
ca *certgen.CertGen
sn atomic.Int64
r *rescheduler.Rescheduler
t *time.Ticker
ts chan struct{}
}
// New creates a new cert list
@ -41,26 +36,15 @@ func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
ss: selfCert,
s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate),
ts: make(chan struct{}, 1),
}
// the rescheduler isn't even used in self cert mode so why initialise it
if !selfCert {
// the rescheduler isn't even used in self cert mode so why initialise it
c.r = rescheduler.NewRescheduler(c.threadCompile)
}
c.t = time.NewTicker(2 * time.Hour)
go func() {
for {
select {
case <-c.t.C:
c.Compile()
case <-c.ts:
return
}
}
}()
} else {
// in self-signed mode generate a CA certificate to sign other certificates
// in self-signed mode generate a CA certificate to sign other certificates
if c.ss {
ca, err := certgen.MakeCaTls(4096, pkix.Name{
Country: []string{"GB"},
Organization: []string{"Violet"},
@ -71,7 +55,7 @@ func New(certDir fs.FS, keyDir fs.FS, selfCert bool) *Certs {
return now.AddDate(10, 0, 0)
})
if err != nil {
logger.Logger.Fatal("Failed to generate CA cert for self-signed mode", "err", err)
log.Fatalln("Failed to generate CA cert for self-signed mode:", err)
}
c.ca = ca
}
@ -133,13 +117,6 @@ func (c *Certs) Compile() {
c.r.Run()
}
func (c *Certs) Stop() {
if c.t != nil {
c.t.Stop()
}
close(c.ts)
}
func (c *Certs) threadCompile() {
// new map
certMap := make(map[string]*tls.Certificate)
@ -147,7 +124,7 @@ func (c *Certs) threadCompile() {
// compile map and check errors
err := c.internalCompile(certMap)
if err != nil {
Logger.Infof("Compile failed: %s\n", err)
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
@ -170,7 +147,7 @@ func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
return fmt.Errorf("failed to read cert dir: %w", err)
}
Logger.Infof("Compiling lookup table for %d certificates\n", len(files))
log.Printf("[Certs] Compiling lookup table for %d certificates\n", len(files))
// find and parse certs
for _, i := range files {
@ -195,10 +172,6 @@ func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
// try to read key file
keyData, err := fs.ReadFile(c.kDir, keyName)
if err != nil {
// ignore the file if the certificate doesn't exist
if os.IsNotExist(err) {
continue
}
return fmt.Errorf("failed to read key file '%s': %w", keyName, err)
}

View File

@ -3,7 +3,7 @@ package certs
import (
"crypto/x509/pkix"
"fmt"
"github.com/mrmelon54/certgen"
"github.com/MrMelon54/certgen"
"github.com/stretchr/testify/assert"
"math/big"
"testing"
@ -16,7 +16,7 @@ func TestCertsNew_Lookup(t *testing.T) {
// type to test that certificate files can be found and read correctly. This
// uses a MapFS for performance during tests.
ca, err := certgen.MakeCaTls(2048, pkix.Name{
ca, err := certgen.MakeCaTls(4096, pkix.Name{
Country: []string{"GB"},
Organization: []string{"Violet"},
OrganizationalUnit: []string{"Development"},
@ -29,7 +29,7 @@ func TestCertsNew_Lookup(t *testing.T) {
domain := "example.com"
sn := int64(1)
serverTls, err := certgen.MakeServerTls(ca, 2048, pkix.Name{
serverTls, err := certgen.MakeServerTls(ca, 4096, pkix.Name{
Country: []string{"GB"},
Organization: []string{domain},
OrganizationalUnit: []string{domain},
@ -63,10 +63,6 @@ func TestCertsNew_Lookup(t *testing.T) {
}
func TestCertsNew_SelfSigned(t *testing.T) {
if testing.Short() {
return
}
certs := New(nil, nil, true)
cc := certs.GetCertForDomain("example.com")
leaf := certgen.TlsLeaf(cc)

View File

@ -2,28 +2,24 @@ package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"github.com/1f349/mjwt"
"github.com/1f349/violet"
"github.com/1f349/violet/certs"
"github.com/1f349/violet/domains"
errorPages "github.com/1f349/violet/error-pages"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/servers"
"github.com/1f349/violet/servers/api"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/charmbracelet/log"
"github.com/cloudflare/tableflip"
"fmt"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/domains"
errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers"
"github.com/MrMelon54/violet/servers/api"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/google/subcommands"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"io/fs"
"log"
"net/http"
"os"
"os/signal"
@ -32,218 +28,156 @@ import (
"time"
)
type serveCmd struct {
configPath string
debugLog bool
pidFile string
}
type serveCmd struct{ configPath string }
func (s *serveCmd) Name() string { return "serve" }
func (s *serveCmd) Synopsis() string { return "Serve reverse proxy server" }
func (s *serveCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.configPath, "conf", "", "/path/to/config.json : path to the config file")
f.BoolVar(&s.debugLog, "debug", false, "enable debug logging")
f.StringVar(&s.pidFile, "pid-file", "", "path to pid file")
}
func (s *serveCmd) Usage() string {
return `serve [-conf <config file>] [-debug] [-pid-file <pid file>]
return `serve [-conf <config file>]
Serve reverse proxy server using information from config file
`
}
func (s *serveCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
if s.debugLog {
logger.Logger.SetLevel(log.DebugLevel)
}
logger.Logger.Info("Starting...")
upg, err := tableflip.New(tableflip.Options{
PIDFile: s.pidFile,
})
if err != nil {
panic(err)
}
defer upg.Stop()
func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
log.Println("[Violet] Starting...")
if s.configPath == "" {
logger.Logger.Info("Error: config flag is missing")
log.Println("[Violet] Error: config flag is missing")
return subcommands.ExitUsageError
}
openConf, err := os.Open(s.configPath)
if err != nil {
if os.IsNotExist(err) {
logger.Logger.Info("Error: missing config file")
log.Println("[Violet] Error: missing config file")
} else {
logger.Logger.Info("Error: open config file: ", err)
log.Println("[Violet] Error: open config file: ", err)
}
return subcommands.ExitFailure
}
var config startUpConfig
err = json.NewDecoder(openConf).Decode(&config)
var conf startUpConfig
err = json.NewDecoder(openConf).Decode(&conf)
if err != nil {
logger.Logger.Info("Error: invalid config file: ", err)
log.Println("[Violet] Error: invalid config file: ", err)
return subcommands.ExitFailure
}
// working directory is the parent of the config file
wd := filepath.Dir(s.configPath)
normalLoad(conf, wd)
return subcommands.ExitSuccess
}
func normalLoad(startUp startUpConfig, wd string) {
// the cert and key paths are useless in self-signed mode
if !config.SelfSigned {
if !startUp.SelfSigned {
// create path to cert dir
err := os.MkdirAll(filepath.Join(wd, "certs"), os.ModePerm)
if err != nil {
logger.Logger.Fatal("Failed to create certificate path")
log.Fatal("[Violet] Failed to create certificate path")
}
// create path to key dir
err = os.MkdirAll(filepath.Join(wd, "keys"), os.ModePerm)
if err != nil {
logger.Logger.Fatal("Failed to create certificate key path")
log.Fatal("[Violet] Failed to create certificate key path")
}
}
// errorPageDir stores an FS interface for accessing the error page directory
var errorPageDir fs.FS
if config.ErrorPagePath != "" {
errorPageDir = os.DirFS(config.ErrorPagePath)
err := os.MkdirAll(config.ErrorPagePath, os.ModePerm)
if startUp.ErrorPagePath != "" {
errorPageDir = os.DirFS(startUp.ErrorPagePath)
err := os.MkdirAll(startUp.ErrorPagePath, os.ModePerm)
if err != nil {
logger.Logger.Fatal("Failed to create error page", "path", config.ErrorPagePath)
log.Fatalf("[Violet] Failed to create error page path '%s'", startUp.ErrorPagePath)
}
}
// load the MJWT RSA public key from a pem encoded file
mJwtVerify, err := mjwt.NewMJwtVerifierFromFile(filepath.Join(wd, "signer.public.pem"))
if err != nil {
logger.Logger.Fatal("Failed to load MJWT verifier public key", "file", filepath.Join(wd, "signer.public.pem"), "err", err)
log.Fatalf("[Violet] Failed to load MJWT verifier public key from file '%s': %s", filepath.Join(wd, "signer.public.pem"), err)
}
// open sqlite database
db, err := violet.InitDB(filepath.Join(wd, "violet.db.sqlite"))
db, err := sql.Open("sqlite3", filepath.Join(wd, "violet.db.sqlite"))
if err != nil {
logger.Logger.Fatal("Failed to open database", "err", err)
log.Fatal("[Violet] Failed to open database")
}
certDir := os.DirFS(filepath.Join(wd, "certs"))
keyDir := os.DirFS(filepath.Join(wd, "keys"))
// setup registry for metrics
promRegistry := prometheus.NewRegistry()
promRegistry.MustRegister(
collectors.NewGoCollector(),
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
)
ws := websocket.NewServer()
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, config.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport(ws) // load reverse proxy
dynamicFavicons := favicons.New(db, config.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(certDir, keyDir, startUp.SelfSigned) // load certificate manager
hybridTransport := proxy.NewHybridTransport() // load reverse proxy
dynamicFavicons := favicons.New(db, startUp.InkscapeCmd) // load dynamic favicon provider
dynamicErrorPages := errorPages.New(errorPageDir) // load dynamic error page provider
dynamicRouter := router.NewManager(db, hybridTransport) // load dynamic router manager
// struct containing config for the http servers
srvConf := &conf.Conf{
RateLimit: config.RateLimit,
DB: db,
Domains: allowedDomains,
Acme: acmeChallenges,
Certs: allowedCerts,
Favicons: dynamicFavicons,
Signer: mJwtVerify,
ErrorPages: dynamicErrorPages,
Router: dynamicRouter,
ApiListen: startUp.Listen.Api,
HttpListen: startUp.Listen.Http,
HttpsListen: startUp.Listen.Https,
RateLimit: startUp.RateLimit,
DB: db,
Domains: allowedDomains,
Acme: acmeChallenges,
Certs: allowedCerts,
Favicons: dynamicFavicons,
Signer: mJwtVerify,
ErrorPages: dynamicErrorPages,
Router: dynamicRouter,
}
// create the compilable list and run a first time compile
allCompilables := utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}
allCompilables.Compile()
_, httpsPort, ok := utils.SplitDomainPort(config.Listen.Https, 443)
if !ok {
httpsPort = 443
}
var srvApi, srvHttp, srvHttps *http.Server
if config.Listen.Api != "" {
// Listen must be called before Ready
lnApi, err := upg.Listen("tcp", config.Listen.Api)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvApi = api.NewApiServer(srvConf, allCompilables, promRegistry)
srvApi.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "API")
l.Info("Starting server", "addr", config.Listen.Api)
go utils.RunBackgroundHttp(l, srvApi, lnApi)
if srvConf.ApiListen != "" {
srvApi = api.NewApiServer(srvConf, allCompilables)
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
go utils.RunBackgroundHttp("API", srvApi)
}
if config.Listen.Http != "" {
// Listen must be called before Ready
lnHttp, err := upg.Listen("tcp", config.Listen.Http)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvHttp = servers.NewHttpServer(uint16(httpsPort), srvConf, promRegistry)
srvHttp.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "HTTP")
l.Info("Starting server", "addr", config.Listen.Http)
go utils.RunBackgroundHttp(l, srvHttp, lnHttp)
if srvConf.HttpListen != "" {
srvHttp = servers.NewHttpServer(srvConf)
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", srvHttp.Addr)
go utils.RunBackgroundHttp("HTTP", srvHttp)
}
if config.Listen.Https != "" {
// Listen must be called before Ready
lnHttps, err := upg.Listen("tcp", config.Listen.Https)
if err != nil {
logger.Logger.Fatal("Listen failed", "err", err)
}
srvHttps = servers.NewHttpsServer(srvConf, promRegistry)
srvHttps.SetKeepAlivesEnabled(false)
l := logger.Logger.With("server", "HTTPS")
l.Info("Starting server", "addr", config.Listen.Https)
go utils.RunBackgroundHttps(l, srvHttps, lnHttps)
if srvConf.HttpsListen != "" {
srvHttps = servers.NewHttpsServer(srvConf)
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", srvHttps.Addr)
go utils.RunBackgroundHttps("HTTPS", srvHttps)
}
// Do an upgrade on SIGHUP
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGHUP)
for range sig {
err := upg.Upgrade()
if err != nil {
logger.Logger.Error("Failed upgrade", "err", err)
}
}
}()
// Wait for exit signal
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill)
<-sc
fmt.Println()
logger.Logger.Info("Ready")
if err := upg.Ready(); err != nil {
panic(err)
}
<-upg.Exit()
time.AfterFunc(30*time.Second, func() {
logger.Logger.Warn("Graceful shutdown timed out")
os.Exit(1)
})
// stop updating certificates
allowedCerts.Stop()
// close websockets first
ws.Shutdown()
// Stop servers
log.Printf("[Violet] Stopping...")
n := time.Now()
// close http servers
if srvApi != nil {
_ = srvApi.Close()
srvApi.Close()
}
if srvHttp != nil {
_ = srvHttp.Close()
srvHttp.Close()
}
if srvHttps != nil {
_ = srvHttps.Close()
srvHttps.Close()
}
return subcommands.ExitSuccess
log.Printf("[Violet] Took '%s' to shutdown\n", time.Now().Sub(n))
log.Println("[Violet] Goodbye")
}

View File

@ -2,18 +2,17 @@ package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"github.com/1f349/violet"
"github.com/1f349/violet/domains"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/target"
"github.com/AlecAivazis/survey/v2"
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/target"
"github.com/google/subcommands"
"log"
"net"
"net/http"
"net/url"
@ -42,7 +41,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
// get absolute path to specify files
wdAbs, err := filepath.Abs(s.wdPath)
if err != nil {
fmt.Println("Failed to get full directory path: ", err)
fmt.Println("[Violet] Failed to get full directory path: ", err)
return subcommands.ExitFailure
}
@ -50,11 +49,11 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
createFile := false
err = survey.AskOne(&survey.Confirm{Message: fmt.Sprintf("Create Violet config files in this directory: '%s'?", wdAbs)}, &createFile)
if err != nil {
fmt.Println("Error: ", err)
fmt.Println("[Violet] Error: ", err)
return subcommands.ExitFailure
}
if !createFile {
fmt.Println("Goodbye")
fmt.Println("[Violet] Goodbye")
return subcommands.ExitSuccess
}
@ -111,7 +110,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
},
}, &answers)
if err != nil {
fmt.Println("Error: ", err)
fmt.Println("[Violet] Error: ", err)
return subcommands.ExitFailure
}
@ -142,14 +141,14 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
RateLimit: answers.RateLimit,
})
if err != nil {
fmt.Println("Failed to write config file: ", err)
fmt.Println("[Violet] Failed to write config file: ", err)
return subcommands.ExitFailure
}
// open sqlite database
db, err := violet.InitDB(databaseFile)
db, err := sql.Open("sqlite3", databaseFile)
if err != nil {
logger.Logger.Fatal("Failed to open database", "err", err)
log.Fatalf("[Violet] Failed to open database '%s'...", databaseFile)
}
// domain manager to add a domain, no need to compile here as the program needs
@ -168,36 +167,33 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{})
return nil
}))
if err != nil {
fmt.Println("Error: ", err)
fmt.Println("[Violet] Error: ", err)
return subcommands.ExitFailure
}
// parse the api url
apiUrl, err := url.Parse(answers.ApiUrl)
if err != nil {
fmt.Println("Failed to parse API URL: ", err)
fmt.Println("[Violet] Failed to parse API URL: ", err)
return subcommands.ExitFailure
}
// add with the route manager, no need to compile as this will run when opened
// with the serve subcommand
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}, &websocket.Server{}))
err = routeManager.InsertRoute(target.RouteWithActive{
Route: target.Route{
Src: path.Join(apiUrl.Host, apiUrl.Path),
Dst: answers.ApiListen,
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
},
Active: true,
routeManager := router.NewManager(db, proxy.NewHybridTransportWithCalls(&nilTransport{}, &nilTransport{}))
err = routeManager.InsertRoute(target.Route{
Src: path.Join(apiUrl.Host, apiUrl.Path),
Dst: answers.ApiListen,
Flags: target.FlagPre | target.FlagCors | target.FlagForwardHost | target.FlagForwardAddr,
})
if err != nil {
fmt.Println("Failed to insert api route into database: ", err)
fmt.Println("[Violet] Failed to insert api route into database: ", err)
return subcommands.ExitFailure
}
}
fmt.Println("Setup complete")
fmt.Printf("Run the reverse proxy with `violet serve -conf %s`\n", confFile)
fmt.Println("[Violet] Setup complete")
fmt.Printf("[Violet] Run the reverse proxy with `violet serve -conf %s`\n", confFile)
return subcommands.ExitSuccess
}

View File

@ -1,31 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
package database
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@ -1,68 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: domain.sql
package database
import (
"context"
)
const addDomain = `-- name: AddDomain :exec
INSERT OR
REPLACE
INTO domains (domain, active)
VALUES (?, ?)
`
type AddDomainParams struct {
Domain string `json:"domain"`
Active bool `json:"active"`
}
func (q *Queries) AddDomain(ctx context.Context, arg AddDomainParams) error {
_, err := q.db.ExecContext(ctx, addDomain, arg.Domain, arg.Active)
return err
}
const deleteDomain = `-- name: DeleteDomain :exec
INSERT OR
REPLACE
INTO domains(domain, active)
VALUES (?, false)
`
func (q *Queries) DeleteDomain(ctx context.Context, domain string) error {
_, err := q.db.ExecContext(ctx, deleteDomain, domain)
return err
}
const getActiveDomains = `-- name: GetActiveDomains :many
SELECT domain
FROM domains
WHERE active = 1
`
func (q *Queries) GetActiveDomains(ctx context.Context) ([]string, error) {
rows, err := q.db.QueryContext(ctx, getActiveDomains)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var domain string
if err := rows.Scan(&domain); err != nil {
return nil, err
}
items = append(items, domain)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -1,74 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: favicon.sql
package database
import (
"context"
"database/sql"
)
const getFavicons = `-- name: GetFavicons :many
SELECT host, svg, png, ico
FROM favicons
`
type GetFaviconsRow struct {
Host string `json:"host"`
Svg sql.NullString `json:"svg"`
Png sql.NullString `json:"png"`
Ico sql.NullString `json:"ico"`
}
func (q *Queries) GetFavicons(ctx context.Context) ([]GetFaviconsRow, error) {
rows, err := q.db.QueryContext(ctx, getFavicons)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetFaviconsRow
for rows.Next() {
var i GetFaviconsRow
if err := rows.Scan(
&i.Host,
&i.Svg,
&i.Png,
&i.Ico,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateFaviconCache = `-- name: UpdateFaviconCache :exec
INSERT OR
REPLACE INTO favicons (host, svg, png, ico)
VALUES (?, ?, ?, ?)
`
type UpdateFaviconCacheParams struct {
Host string `json:"host"`
Svg sql.NullString `json:"svg"`
Png sql.NullString `json:"png"`
Ico sql.NullString `json:"ico"`
}
func (q *Queries) UpdateFaviconCache(ctx context.Context, arg UpdateFaviconCacheParams) error {
_, err := q.db.ExecContext(ctx, updateFaviconCache,
arg.Host,
arg.Svg,
arg.Png,
arg.Ico,
)
return err
}

View File

@ -1,4 +0,0 @@
DROP TABLE domains;
DROP TABLE favicons;
DROP TABLE routes;
DROP TABLE redirects;

View File

@ -1,36 +0,0 @@
CREATE TABLE IF NOT EXISTS domains
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain TEXT UNIQUE NOT NULL,
active BOOLEAN NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS favicons
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
host VARCHAR NOT NULL,
svg VARCHAR,
png VARCHAR,
ico VARCHAR
);
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE NOT NULL,
destination TEXT NOT NULL,
description TEXT NOT NULL,
flags INTEGER NOT NULL DEFAULT 0,
active BOOLEAN NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE NOT NULL,
destination TEXT NOT NULL,
description TEXT NOT NULL,
flags INTEGER NOT NULL DEFAULT 0,
code INTEGER NOT NULL DEFAULT 0,
active BOOLEAN NOT NULL DEFAULT 1
);

View File

@ -1,44 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
package database
import (
"database/sql"
"github.com/1f349/violet/target"
)
type Domain struct {
ID int64 `json:"id"`
Domain string `json:"domain"`
Active bool `json:"active"`
}
type Favicon struct {
ID int64 `json:"id"`
Host string `json:"host"`
Svg sql.NullString `json:"svg"`
Png sql.NullString `json:"png"`
Ico sql.NullString `json:"ico"`
}
type Redirect struct {
ID int64 `json:"id"`
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
Active bool `json:"active"`
}
type Route struct {
ID int64 `json:"id"`
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Active bool `json:"active"`
}

View File

@ -1,16 +0,0 @@
-- name: GetActiveDomains :many
SELECT domain
FROM domains
WHERE active = 1;
-- name: AddDomain :exec
INSERT OR
REPLACE
INTO domains (domain, active)
VALUES (?, ?);
-- name: DeleteDomain :exec
INSERT OR
REPLACE
INTO domains(domain, active)
VALUES (?, false);

View File

@ -1,8 +0,0 @@
-- name: GetFavicons :many
SELECT host, svg, png, ico
FROM favicons;
-- name: UpdateFaviconCache :exec
INSERT OR
REPLACE INTO favicons (host, svg, png, ico)
VALUES (?, ?, ?, ?);

View File

@ -1,39 +0,0 @@
-- name: GetActiveRoutes :many
SELECT source, destination, flags
FROM routes
WHERE active = 1;
-- name: GetActiveRedirects :many
SELECT source, destination, flags, code
FROM redirects
WHERE active = 1;
-- name: GetAllRoutes :many
SELECT source, destination, description, flags, active
FROM routes;
-- name: GetAllRedirects :many
SELECT source, destination, description, flags, code, active
FROM redirects;
-- name: AddRoute :exec
INSERT OR
REPLACE
INTO routes (source, destination, description, flags, active)
VALUES (?, ?, ?, ?, ?);
-- name: AddRedirect :exec
INSERT OR
REPLACE
INTO redirects (source, destination, description, flags, code, active)
VALUES (?, ?, ?, ?, ?, ?);
-- name: RemoveRoute :exec
DELETE
FROM routes
WHERE source = ?;
-- name: RemoveRedirect :exec
DELETE
FROM redirects
WHERE source = ?;

View File

@ -1,250 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: routing.sql
package database
import (
"context"
"github.com/1f349/violet/target"
)
const addRedirect = `-- name: AddRedirect :exec
INSERT OR
REPLACE
INTO redirects (source, destination, description, flags, code, active)
VALUES (?, ?, ?, ?, ?, ?)
`
type AddRedirectParams struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
Active bool `json:"active"`
}
func (q *Queries) AddRedirect(ctx context.Context, arg AddRedirectParams) error {
_, err := q.db.ExecContext(ctx, addRedirect,
arg.Source,
arg.Destination,
arg.Description,
arg.Flags,
arg.Code,
arg.Active,
)
return err
}
const addRoute = `-- name: AddRoute :exec
INSERT OR
REPLACE
INTO routes (source, destination, description, flags, active)
VALUES (?, ?, ?, ?, ?)
`
type AddRouteParams struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Active bool `json:"active"`
}
func (q *Queries) AddRoute(ctx context.Context, arg AddRouteParams) error {
_, err := q.db.ExecContext(ctx, addRoute,
arg.Source,
arg.Destination,
arg.Description,
arg.Flags,
arg.Active,
)
return err
}
const getActiveRedirects = `-- name: GetActiveRedirects :many
SELECT source, destination, flags, code
FROM redirects
WHERE active = 1
`
type GetActiveRedirectsRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
}
func (q *Queries) GetActiveRedirects(ctx context.Context) ([]GetActiveRedirectsRow, error) {
rows, err := q.db.QueryContext(ctx, getActiveRedirects)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetActiveRedirectsRow
for rows.Next() {
var i GetActiveRedirectsRow
if err := rows.Scan(
&i.Source,
&i.Destination,
&i.Flags,
&i.Code,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getActiveRoutes = `-- name: GetActiveRoutes :many
SELECT source, destination, flags
FROM routes
WHERE active = 1
`
type GetActiveRoutesRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Flags target.Flags `json:"flags"`
}
func (q *Queries) GetActiveRoutes(ctx context.Context) ([]GetActiveRoutesRow, error) {
rows, err := q.db.QueryContext(ctx, getActiveRoutes)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetActiveRoutesRow
for rows.Next() {
var i GetActiveRoutesRow
if err := rows.Scan(&i.Source, &i.Destination, &i.Flags); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAllRedirects = `-- name: GetAllRedirects :many
SELECT source, destination, description, flags, code, active
FROM redirects
`
type GetAllRedirectsRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Code int64 `json:"code"`
Active bool `json:"active"`
}
func (q *Queries) GetAllRedirects(ctx context.Context) ([]GetAllRedirectsRow, error) {
rows, err := q.db.QueryContext(ctx, getAllRedirects)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAllRedirectsRow
for rows.Next() {
var i GetAllRedirectsRow
if err := rows.Scan(
&i.Source,
&i.Destination,
&i.Description,
&i.Flags,
&i.Code,
&i.Active,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAllRoutes = `-- name: GetAllRoutes :many
SELECT source, destination, description, flags, active
FROM routes
`
type GetAllRoutesRow struct {
Source string `json:"source"`
Destination string `json:"destination"`
Description string `json:"description"`
Flags target.Flags `json:"flags"`
Active bool `json:"active"`
}
func (q *Queries) GetAllRoutes(ctx context.Context) ([]GetAllRoutesRow, error) {
rows, err := q.db.QueryContext(ctx, getAllRoutes)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAllRoutesRow
for rows.Next() {
var i GetAllRoutesRow
if err := rows.Scan(
&i.Source,
&i.Destination,
&i.Description,
&i.Flags,
&i.Active,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const removeRedirect = `-- name: RemoveRedirect :exec
DELETE
FROM redirects
WHERE source = ?
`
func (q *Queries) RemoveRedirect(ctx context.Context, source string) error {
_, err := q.db.ExecContext(ctx, removeRedirect, source)
return err
}
const removeRoute = `-- name: RemoveRoute :exec
DELETE
FROM routes
WHERE source = ?
`
func (q *Queries) RemoveRoute(ctx context.Context, source string) error {
_, err := q.db.ExecContext(ctx, removeRoute, source)
return err
}

View File

@ -0,0 +1,6 @@
CREATE TABLE IF NOT EXISTS domains
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain TEXT UNIQUE,
active INTEGER DEFAULT 1
);

View File

@ -1,34 +1,41 @@
package domains
import (
"context"
"database/sql"
_ "embed"
"github.com/1f349/violet/database"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/utils"
"github.com/mrmelon54/rescheduler"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/utils"
"log"
"strings"
"sync"
)
var Logger = logger.Logger.WithPrefix("Violet Domains")
//go:embed create-table-domains.sql
var createTableDomains string
// Domains is the domain list and management system.
type Domains struct {
db *database.Queries
db *sql.DB
s *sync.RWMutex
m map[string]struct{}
r *rescheduler.Rescheduler
}
// New creates a new domain list
func New(db *database.Queries) *Domains {
func New(db *sql.DB) *Domains {
a := &Domains{
db: db,
s: &sync.RWMutex{},
m: make(map[string]struct{}),
}
a.r = rescheduler.NewRescheduler(a.threadCompile)
// init domains table
_, err := a.db.Exec(createTableDomains)
if err != nil {
log.Printf("[WARN] Failed to generate 'domains' table\n")
return nil
}
return a
}
@ -70,7 +77,7 @@ func (d *Domains) threadCompile() {
// compile map and check errors
err := d.internalCompile(domainMap)
if err != nil {
Logger.Info("Compile faile", "err", err)
log.Printf("[Domains] Compile failed: %s\n", err)
return
}
@ -83,39 +90,43 @@ func (d *Domains) threadCompile() {
// internalCompile is a hidden internal method for querying the database during
// the Compile() method.
func (d *Domains) internalCompile(m map[string]struct{}) error {
Logger.Info("Updating domains from database")
log.Println("[Domains] Updating domains from database")
// sql or something?
rows, err := d.db.GetActiveDomains(context.Background())
rows, err := d.db.Query(`select domain from domains where active = 1`)
if err != nil {
return err
}
defer rows.Close()
for _, i := range rows {
m[i] = struct{}{}
// loop through rows and scan the allowed domain names
for rows.Next() {
var name string
err = rows.Scan(&name)
if err != nil {
return err
}
m[name] = struct{}{}
}
// check for errors
return nil
return rows.Err()
}
func (d *Domains) Put(domain string, active bool) {
d.s.Lock()
defer d.s.Unlock()
err := d.db.AddDomain(context.Background(), database.AddDomainParams{
Domain: domain,
Active: active,
})
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, active)
if err != nil {
logger.Logger.Infof("Database error: %s\n", err)
log.Printf("[Violet] Database error: %s\n", err)
}
}
func (d *Domains) Delete(domain string) {
d.s.Lock()
defer d.s.Unlock()
err := d.db.DeleteDomain(context.Background(), domain)
_, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, false)
if err != nil {
logger.Logger.Infof("Database error: %s\n", err)
log.Printf("[Violet] Database error: %s\n", err)
}
}

View File

@ -1,20 +1,18 @@
package domains
import (
"context"
"github.com/1f349/violet"
"github.com/1f349/violet/database"
"database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDomainsNew(t *testing.T) {
db, err := violet.InitDB("file:TestDomainsNew?mode=memory&cache=shared")
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
domains := New(db)
err = db.AddDomain(context.Background(), database.AddDomainParams{Domain: "example.com", Active: true})
_, err = db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1)
assert.NoError(t, err)
domains.Compile()
@ -29,11 +27,11 @@ func TestDomainsNew(t *testing.T) {
func TestDomains_IsValid(t *testing.T) {
// open sqlite database
db, err := violet.InitDB("file:TestDomains_IsValid?mode=memory&cache=shared")
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
domains := New(db)
err = db.AddDomain(context.Background(), database.AddDomainParams{Domain: "example.com", Active: true})
_, err = domains.db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1)
assert.NoError(t, err)
domains.s.Lock()

View File

@ -2,9 +2,9 @@ package error_pages
import (
"fmt"
"github.com/1f349/violet/logger"
"github.com/mrmelon54/rescheduler"
"github.com/MrMelon54/rescheduler"
"io/fs"
"log"
"net/http"
"path/filepath"
"strconv"
@ -12,8 +12,6 @@ import (
"sync"
)
var Logger = logger.Logger.WithPrefix("Violet Error Pages")
// ErrorPages stores the custom error pages and is called by the servers to
// output meaningful pages for HTTP error codes
type ErrorPages struct {
@ -80,7 +78,7 @@ func (e *ErrorPages) threadCompile() {
if e.dir != nil {
err := e.internalCompile(errorPageMap)
if err != nil {
Logger.Info("Compile failed", "err", err)
log.Printf("[ErrorPages] Compile failed: %s\n", err)
return
}
}
@ -98,7 +96,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
return fmt.Errorf("failed to read error pages dir: %w", err)
}
Logger.Info("Compiling lookup table", "page count", len(files))
log.Printf("[ErrorPages] Compiling lookup table for %d error pages\n", len(files))
// find and load error pages
for _, i := range files {
@ -113,20 +111,20 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
// if the extension is not 'html' then ignore the file
if ext != ".html" {
Logger.Warn("Ignoring non '.html' file in error pages directory", "name", name)
log.Printf("[ErrorPages] WARNING: ignoring non '.html' file in error pages directory: '%s'\n", name)
continue
}
// if the name can't be
nameInt, err := strconv.Atoi(strings.TrimSuffix(name, ".html"))
if err != nil {
Logger.Warn("Ignoring invalid error page in error pages directory", "name", name)
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory: '%s'\n", name)
continue
}
// check if code is in range 100-599
if nameInt < 100 || nameInt >= 600 {
Logger.Warn("Ignoring invalid error page in error pages directory must be 100-599", "name", name)
log.Printf("[ErrorPages] WARNING: ignoring invalid error page in error pages directory must be 100-599: '%s'\n", name)
continue
}

View File

@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS favicons (
id INTEGER PRIMARY KEY AUTOINCREMENT,
host VARCHAR,
svg VARCHAR,
png VARCHAR,
ico VARCHAR
);

View File

@ -1,7 +1,5 @@
package favicons
import "database/sql"
// FaviconImage stores the url, hash and raw bytes of an image
type FaviconImage struct {
Url string
@ -11,9 +9,9 @@ type FaviconImage struct {
// CreateFaviconImage outputs a FaviconImage with the specified URL or nil if
// the URL is an empty string.
func CreateFaviconImage(url sql.NullString) *FaviconImage {
if !url.Valid {
func CreateFaviconImage(url string) *FaviconImage {
if url == "" {
return nil
}
return &FaviconImage{Url: url.String}
return &FaviconImage{Url: url}
}

View File

@ -6,7 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/mrmelon54/png2ico"
"github.com/MrMelon54/png2ico"
"image/png"
"io"
"net/http"
@ -74,7 +74,7 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
// download SVG
l.Svg.Raw, err = getFaviconViaRequest(l.Svg.Url)
if err != nil {
return fmt.Errorf("favicons: failed to fetch SVG icon: %w", err)
return fmt.Errorf("[Favicons] Failed to fetch SVG icon: %w", err)
}
l.Svg.Hash = hex.EncodeToString(sha256.New().Sum(l.Svg.Raw))
}
@ -84,14 +84,14 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
// download PNG
l.Png.Raw, err = getFaviconViaRequest(l.Png.Url)
if err != nil {
return fmt.Errorf("favicons: failed to fetch PNG icon: %w", err)
return fmt.Errorf("[Favicons] Failed to fetch PNG icon: %w", err)
}
} else if l.Svg != nil {
// generate PNG from SVG
l.Png = &FaviconImage{}
l.Png.Raw, err = convert(l.Svg.Raw)
if err != nil {
return fmt.Errorf("favicons: failed to generate PNG icon: %w", err)
return fmt.Errorf("[Favicons] Failed to generate PNG icon: %w", err)
}
}
@ -100,19 +100,19 @@ func (l *FaviconList) PreProcess(convert func(in []byte) ([]byte, error)) error
// download ICO
l.Ico.Raw, err = getFaviconViaRequest(l.Ico.Url)
if err != nil {
return fmt.Errorf("favicons: failed to fetch ICO icon: %w", err)
return fmt.Errorf("[Favicons] Failed to fetch ICO icon: %w", err)
}
} else if l.Png != nil {
// generate ICO from PNG
l.Ico = &FaviconImage{}
decode, err := png.Decode(bytes.NewReader(l.Png.Raw))
if err != nil {
return fmt.Errorf("favicons: failed to decode PNG icon: %w", err)
return fmt.Errorf("[Favicons] Failed to decode PNG icon: %w", err)
}
b := decode.Bounds()
l.Ico.Raw, err = png2ico.ConvertPngToIco(l.Png.Raw, b.Dx(), b.Dy())
if err != nil {
return fmt.Errorf("favicons: failed to generate ICO icon: %w", err)
return fmt.Errorf("[Favicons] Failed to generate ICO icon: %w", err)
}
}
@ -139,16 +139,16 @@ func (l *FaviconList) genSha256() {
var getFaviconViaRequest = func(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("favicons: Failed to send request '%s': %w", url, err)
return nil, fmt.Errorf("[Favicons] Failed to send request '%s': %w", url, err)
}
req.Header.Set("X-Violet-Raw-Favicon", "1")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("favicons: failed to do request '%s': %w", url, err)
return nil, fmt.Errorf("[Favicons] Failed to do request '%s': %w", url, err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("favicons: failed to read response '%s': %w", url, err)
return nil, fmt.Errorf("[Favicons] Failed to read response '%s': %w", url, err)
}
return rawBody, nil
}

View File

@ -1,24 +1,24 @@
package favicons
import (
"context"
"database/sql"
_ "embed"
"errors"
"fmt"
"github.com/1f349/violet/database"
"github.com/1f349/violet/logger"
"github.com/mrmelon54/rescheduler"
"github.com/MrMelon54/rescheduler"
"golang.org/x/sync/errgroup"
"log"
"sync"
)
var Logger = logger.Logger.WithPrefix("Violet Favicons")
var ErrFaviconNotFound = errors.New("favicon not found")
//go:embed create-table-favicons.sql
var createTableFavicons string
// Favicons is a dynamic favicon generator which supports overwriting favicons
type Favicons struct {
db *database.Queries
db *sql.DB
cmd string
cLock *sync.RWMutex
faviconMap map[string]*FaviconList
@ -26,7 +26,7 @@ type Favicons struct {
}
// New creates a new dynamic favicon generator
func New(db *database.Queries, inkscapeCmd string) *Favicons {
func New(db *sql.DB, inkscapeCmd string) *Favicons {
f := &Favicons{
db: db,
cmd: inkscapeCmd,
@ -35,6 +35,13 @@ func New(db *database.Queries, inkscapeCmd string) *Favicons {
}
f.r = rescheduler.NewRescheduler(f.threadCompile)
// init favicons table
_, err := f.db.Exec(createTableFavicons)
if err != nil {
log.Printf("[WARN] Failed to generate 'favicons' table\n")
return nil
}
// run compile to get the initial data
f.Compile()
return f
@ -68,7 +75,7 @@ func (f *Favicons) threadCompile() {
err := f.internalCompile(favicons)
if err != nil {
// log compile errors
Logger.Info("Compile failed", "err", err)
log.Printf("[Favicons] Compile failed: %s\n", err)
return
}
@ -82,23 +89,29 @@ func (f *Favicons) threadCompile() {
// favicons.
func (f *Favicons) internalCompile(m map[string]*FaviconList) error {
// query all rows in database
rows, err := f.db.GetFavicons(context.Background())
query, err := f.db.Query(`select host, svg, png, ico from favicons`)
if err != nil {
return fmt.Errorf("failed to prepare rows: %w", err)
return fmt.Errorf("failed to prepare query: %w", err)
}
// loop over rows and scan in data using error group to catch errors
var g errgroup.Group
for _, row := range rows {
for query.Next() {
var host, rawSvg, rawPng, rawIco string
err := query.Scan(&host, &rawSvg, &rawPng, &rawIco)
if err != nil {
return fmt.Errorf("failed to scan row: %w", err)
}
// create favicon list for this row
l := &FaviconList{
Ico: CreateFaviconImage(row.Ico),
Png: CreateFaviconImage(row.Png),
Svg: CreateFaviconImage(row.Svg),
Ico: CreateFaviconImage(rawIco),
Png: CreateFaviconImage(rawPng),
Svg: CreateFaviconImage(rawSvg),
}
// save the favicon list to the map
m[row.Host] = l
m[host] = l
// run the pre-process in a separate goroutine
g.Go(func() error {

View File

@ -2,11 +2,8 @@ package favicons
import (
"bytes"
"context"
"database/sql"
_ "embed"
"github.com/1f349/violet"
"github.com/1f349/violet/database"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"image/png"
@ -25,17 +22,11 @@ var (
func TestFaviconsNew(t *testing.T) {
getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil }
db, err := violet.InitDB("file:TestFaviconsNew?mode=memory&cache=shared")
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
favicons := New(db, "inkscape")
err = db.UpdateFaviconCache(context.Background(), database.UpdateFaviconCacheParams{
Host: "example.com",
Svg: sql.NullString{
String: "https://example.com/assets/logo.svg",
Valid: true,
},
})
_, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "")
assert.NoError(t, err)
favicons.cLock.Lock()
assert.NoError(t, favicons.internalCompile(favicons.faviconMap))

67
go.mod
View File

@ -1,62 +1,37 @@
module github.com/1f349/violet
module github.com/MrMelon54/violet
go 1.22
go 1.20
require (
github.com/1f349/mjwt v0.2.5
github.com/AlecAivazis/survey/v2 v2.3.7
github.com/charmbracelet/log v0.4.0
github.com/cloudflare/tableflip v1.2.3
github.com/golang-migrate/migrate/v4 v4.17.1
github.com/MrMelon54/certgen v0.0.1
github.com/MrMelon54/mjwt v0.1.1
github.com/MrMelon54/png2ico v1.0.1
github.com/MrMelon54/rescheduler v0.0.1
github.com/MrMelon54/trie v0.0.2
github.com/google/subcommands v1.2.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.1
github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/mrmelon54/certgen v0.0.2
github.com/mrmelon54/png2ico v1.0.2
github.com/mrmelon54/rescheduler v0.0.3
github.com/mrmelon54/trie v0.0.3
github.com/prometheus/client_golang v1.19.1
github.com/rs/cors v1.11.0
github.com/sethvargo/go-limiter v1.0.0
github.com/stretchr/testify v1.9.0
golang.org/x/net v0.25.0
golang.org/x/sync v0.7.0
github.com/mattn/go-sqlite3 v1.14.16
github.com/rs/cors v1.9.0
github.com/sethvargo/go-limiter v0.7.2
github.com/stretchr/testify v1.8.4
golang.org/x/net v0.9.0
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/becheran/wildmatch-go v1.0.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/lipgloss v0.10.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/mattn/go-colorable v0.1.2 // indirect
github.com/mattn/go-isatty v0.0.8 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.14.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/term v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/term v0.7.0 // indirect
golang.org/x/text v0.9.0 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

142
go.sum
View File

@ -1,162 +1,94 @@
github.com/1f349/mjwt v0.2.5 h1:IxjLaali22ayTzZ628lH7j0JDdYJoj6+CJ/VktCqtXQ=
github.com/1f349/mjwt v0.2.5/go.mod h1:KEs6jd9JjWrQW+8feP2pGAU7pdA3aYTqjkT/YQr73PU=
github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ=
github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo=
github.com/MrMelon54/certgen v0.0.1 h1:ycWdZ2RlxQ5qSuejeBVv4aXjGo5hdqqL4j4EjrXnFMk=
github.com/MrMelon54/certgen v0.0.1/go.mod h1:GHflVlSbtFLJZLpN1oWyUvDBRrR8qCWiwZLXCCnS2Gc=
github.com/MrMelon54/mjwt v0.1.1 h1:m+aTpxbhQCrOPKHN170DQMFR5r938LkviU38unob5Jw=
github.com/MrMelon54/mjwt v0.1.1/go.mod h1:oYrDBWK09Hju98xb+bRQ0wy+RuAzacxYvKYOZchR2Tk=
github.com/MrMelon54/png2ico v1.0.1 h1:zJoSSl4OkvSIMWGyGPvb8fWNa0KrUvMIjgNGLNLJhVQ=
github.com/MrMelon54/png2ico v1.0.1/go.mod h1:NOv3tO4497mInG+3tcFkIohmxCywUwMLU8WNxJZLVmU=
github.com/MrMelon54/rescheduler v0.0.1 h1:gzNvL8X81M00uYN0i9clFVrXCkG1UuLNYxDcvjKyBqo=
github.com/MrMelon54/rescheduler v0.0.1/go.mod h1:OQDFtZHdS4/qA/r7rtJUQA22/hbpnZ9MGQCXOPjhC6w=
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
github.com/MrMelon54/trie v0.0.2/go.mod h1:sGCGOcqb+DxSxvHgSOpbpkmA7mFZR47YDExy9OCbVZI=
github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s=
github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/charmbracelet/log v0.4.0 h1:G9bQAcx8rWA2T3pWvx7YtPTPwgqpk7D68BX21IRW8ZM=
github.com/charmbracelet/log v0.4.0/go.mod h1:63bXt/djrizTec0l11H20t8FDSvA4CRZJ1KH22MdptM=
github.com/cloudflare/tableflip v1.2.3 h1:8I+B99QnnEWPHOY3fWipwVKxS70LGgUsslG7CSfmHMw=
github.com/cloudflare/tableflip v1.2.3/go.mod h1:P4gRehmV6Z2bY5ao5ml9Pd8u6kuEnlB37pUFMmv7j2E=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI=
github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4=
github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/mrmelon54/certgen v0.0.2 h1:4CMDkA/gGZu+E4iikU+5qdOWK7qOQrk58KtUfnmyYmY=
github.com/mrmelon54/certgen v0.0.2/go.mod h1:vwrWSXQmxZYqEyh+cf05IvDIFV2aYuxL4+O6ABIlN8M=
github.com/mrmelon54/png2ico v1.0.2 h1:KyJd3ATmDjxAJS28MTSf44GxzYnlZ+7KT8SXzGb3sN8=
github.com/mrmelon54/png2ico v1.0.2/go.mod h1:vp8Be9y5cz102ANon+BnsIzTUdet3VQRvOuWJTH9h0M=
github.com/mrmelon54/rescheduler v0.0.3 h1:TrkJL6S7PKvXuo1mvdgRgsILA/pk5L1lrXhV/q7IEzQ=
github.com/mrmelon54/rescheduler v0.0.3/go.mod h1:q415n6W1xcePPP5Rix6FOiADgcN66BYMyNOsFnNyoWQ=
github.com/mrmelon54/trie v0.0.3 h1:wZmws84FiGNBZJ00garLyQ2EQhtx0SipVoV7fK8+kZE=
github.com/mrmelon54/trie v0.0.3/go.mod h1:d3hl0YUBSWR3XN4S9BDLkGVzLT4VgwP2mZkBJM6uFpw=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE=
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.14.0 h1:Lw4VdGGoKEZilJsayHf0B+9YgLGREba2C6xr+Fdfq6s=
github.com/prometheus/procfs v0.14.0/go.mod h1:XL+Iwz8k8ZabyZfMFHPiilCniixqQarAy5Mu67pHlNQ=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po=
github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sethvargo/go-limiter v1.0.0 h1:JqW13eWEMn0VFv86OKn8wiYJY/m250WoXdrjRV0kLe4=
github.com/sethvargo/go-limiter v1.0.0/go.mod h1:01b6tW25Ap+MeLYBuD4aHunMrJoNO5PVUFdS9rac3II=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.7.0 h1:BEvjmm5fURWqcfbSKTdpkDXYBrUS1c0m8agp14W48vQ=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,38 +0,0 @@
package violet
import (
"database/sql"
"embed"
"errors"
"github.com/1f349/violet/database"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
)
//go:embed database/migrations/*.sql
var migrations embed.FS
func InitDB(p string) (*database.Queries, error) {
migDrv, err := iofs.New(migrations, "database/migrations")
if err != nil {
return nil, err
}
dbOpen, err := sql.Open("sqlite3", p)
if err != nil {
return nil, err
}
dbDrv, err := sqlite3.WithInstance(dbOpen, &sqlite3.Config{})
if err != nil {
return nil, err
}
mig, err := migrate.NewWithInstance("iofs", migDrv, "sqlite3", dbDrv)
if err != nil {
return nil, err
}
err = mig.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return nil, err
}
return database.New(dbOpen), nil
}

View File

@ -1,12 +0,0 @@
package logger
import (
"github.com/charmbracelet/log"
"os"
)
var Logger = log.NewWithOptions(os.Stderr, log.Options{
ReportCaller: true,
ReportTimestamp: true,
Prefix: "Violet",
})

View File

@ -2,37 +2,30 @@ package proxy
import (
"crypto/tls"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy/websocket"
"net"
"net/http"
"sync"
"time"
)
var loggerSecure = logger.Logger.WithPrefix("Violet Secure Transport")
var loggerInsecure = logger.Logger.WithPrefix("Violet Insecure Transport")
var loggerWebsocket = logger.Logger.WithPrefix("Violet Websocket Transport")
type HybridTransport struct {
baseDialer *net.Dialer
normalTransport http.RoundTripper
insecureTransport http.RoundTripper
socksSync *sync.RWMutex
socksTransport map[string]http.RoundTripper
ws *websocket.Server
}
// NewHybridTransport creates a new hybrid transport
func NewHybridTransport(ws *websocket.Server) *HybridTransport {
return NewHybridTransportWithCalls(nil, nil, ws)
func NewHybridTransport() *HybridTransport {
return NewHybridTransportWithCalls(nil, nil)
}
// NewHybridTransportWithCalls creates new hybrid transport with custom normal
// and insecure http.RoundTripper functions.
//
// NewHybridTransportWithCalls(nil, nil) is equivalent to NewHybridTransport()
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websocket.Server) *HybridTransport {
func NewHybridTransportWithCalls(normal, insecure http.RoundTripper) *HybridTransport {
h := &HybridTransport{
baseDialer: &net.Dialer{
Timeout: 30 * time.Second,
@ -40,7 +33,6 @@ func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websock
},
normalTransport: normal,
insecureTransport: insecure,
ws: ws,
}
if h.normalTransport == nil {
h.normalTransport = &http.Transport{
@ -52,7 +44,6 @@ func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websock
IdleConnTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
DisableKeepAlives: true,
}
}
if h.insecureTransport == nil {
@ -66,7 +57,6 @@ func NewHybridTransportWithCalls(normal, insecure http.RoundTripper, ws *websock
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
DisableKeepAlives: true,
}
}
return h
@ -81,8 +71,3 @@ func (h *HybridTransport) SecureRoundTrip(req *http.Request) (*http.Response, er
func (h *HybridTransport) InsecureRoundTrip(req *http.Request) (*http.Response, error) {
return h.insecureTransport.RoundTrip(req)
}
// ConnectWebsocket calls the websocket upgrader and thus hijacks the connection
func (h *HybridTransport) ConnectWebsocket(rw http.ResponseWriter, req *http.Request) {
h.ws.Upgrade(rw, req)
}

View File

@ -7,7 +7,7 @@ import (
)
func TestNewHybridTransport(t *testing.T) {
h := NewHybridTransport(nil)
h := NewHybridTransport()
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
assert.NoError(t, err)
trip, err := h.SecureRoundTrip(req)

View File

@ -1,134 +0,0 @@
package websocket
import (
"github.com/1f349/violet/logger"
"github.com/gorilla/websocket"
"net/http"
"slices"
"sync"
"time"
)
var Logger = logger.Logger.WithPrefix("Violet Websocket")
var upgrader = websocket.Upgrader{
HandshakeTimeout: time.Second * 5,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// allow requests from any origin
// the internal service can decide what origins to allow
return true
},
}
type Server struct {
connLock *sync.RWMutex
connStop bool
conns map[string]*websocket.Conn
}
func NewServer() *Server {
return &Server{
connLock: new(sync.RWMutex),
conns: make(map[string]*websocket.Conn),
}
}
func (s *Server) Upgrade(rw http.ResponseWriter, req *http.Request) {
req.URL.Scheme = "ws"
Logger.Info("Upgrading request", "url", req.URL, "origin", req.Header.Get("Origin"))
c, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
return
}
defer c.Close()
s.connLock.Lock()
// no more connections allowed
if s.connStop {
s.connLock.Unlock()
return
}
// save connection for shutdown
s.conns[c.RemoteAddr().String()] = c
s.connLock.Unlock()
Logger.Info("Dialing", "url", req.URL)
// dial for internal connection
ic, _, err := websocket.DefaultDialer.DialContext(req.Context(), req.URL.String(), filterWebsocketHeaders(req.Header))
if err != nil {
Logger.Info("Failed to dial", "url", req.URL, "err", err)
s.Remove(c)
return
}
defer ic.Close()
d1 := make(chan struct{}, 1)
d2 := make(chan struct{}, 1)
// relay messages each way
go s.wsRelay(d1, c, ic)
go s.wsRelay(d2, ic, c)
// wait for done signal and close both connections
Logger.Info("Completed websocket hijacking")
// waiting until d1 or d2 close then automatically defer close both connections
select {
case <-d1:
case <-d2:
}
}
// filterWebsocketHeaders allows specific headers to forward to the underlying websocket connection
func filterWebsocketHeaders(headers http.Header) (out http.Header) {
out = make(http.Header)
for k, v := range headers {
if k == "Origin" {
out[k] = slices.Clone(v)
}
}
return
}
func (s *Server) wsRelay(done chan struct{}, a, b *websocket.Conn) {
defer func() {
close(done)
}()
for {
mt, message, err := a.ReadMessage()
if err != nil {
Logger.Info("Read message", "err", err)
return
}
if b.WriteMessage(mt, message) != nil {
return
}
}
}
func (s *Server) Remove(c *websocket.Conn) {
s.connLock.Lock()
delete(s.conns, c.RemoteAddr().String())
s.connLock.Unlock()
_ = c.Close()
}
func (s *Server) Shutdown() {
s.connLock.Lock()
defer s.connLock.Unlock()
// flag shutdown and close all open connections
s.connStop = true
for _, i := range s.conns {
_ = i.Close()
}
// clear connections, not required but do it anyway
s.conns = make(map[string]*websocket.Conn)
}

18
router/create-tables.sql Normal file
View File

@ -0,0 +1,18 @@
CREATE TABLE IF NOT EXISTS routes
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE,
destination TEXT,
flags INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
);
CREATE TABLE IF NOT EXISTS redirects
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
source TEXT UNIQUE,
destination TEXT,
flags INTEGER DEFAULT 0,
code INTEGER DEFAULT 0,
active INTEGER DEFAULT 1
);

View File

@ -1,33 +1,34 @@
package router
import (
"context"
"database/sql"
_ "embed"
"github.com/1f349/violet/database"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/target"
"github.com/mrmelon54/rescheduler"
"github.com/MrMelon54/rescheduler"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"log"
"net/http"
"strings"
"sync"
)
var Logger = logger.Logger.WithPrefix("Violet Manager")
// Manager is a database and mutex wrap around router allowing it to be
// dynamically regenerated after updating the database of routes.
type Manager struct {
db *database.Queries
db *sql.DB
s *sync.RWMutex
r *Router
p *proxy.HybridTransport
z *rescheduler.Rescheduler
}
var (
//go:embed create-tables.sql
createTables string
)
// NewManager create a new manager, initialises the routes and redirects tables
// in the database and runs a first time compile.
func NewManager(db *database.Queries, proxy *proxy.HybridTransport) *Manager {
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
m := &Manager{
db: db,
s: &sync.RWMutex{},
@ -35,14 +36,20 @@ func NewManager(db *database.Queries, proxy *proxy.HybridTransport) *Manager {
p: proxy,
}
m.z = rescheduler.NewRescheduler(m.threadCompile)
// init routes table
_, err := m.db.Exec(createTables)
if err != nil {
log.Printf("[WARN] Failed to generate tables\n")
return nil
}
return m
}
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
m.s.RLock()
r := m.r
m.r.ServeHTTP(rw, req)
m.s.RUnlock()
r.ServeHTTP(rw, req)
}
func (m *Manager) Compile() {
@ -56,7 +63,7 @@ func (m *Manager) threadCompile() {
// compile router and check errors
err := m.internalCompile(router)
if err != nil {
Logger.Info("Compile failed", "err", err)
log.Printf("[Manager] Compile failed: %s\n", err)
return
}
@ -69,160 +76,123 @@ func (m *Manager) threadCompile() {
// internalCompile is a hidden internal method for querying the database during
// the Compile() method.
func (m *Manager) internalCompile(router *Router) error {
Logger.Info("Updating routes from database")
log.Println("[Manager] Updating routes from database")
// sql or something?
routeRows, err := m.db.GetActiveRoutes(context.Background())
rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`)
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the options
for rows.Next() {
var (
src, dst string
flags target.Flags
)
err := rows.Scan(&src, &dst, &flags)
if err != nil {
return err
}
for _, row := range routeRows {
router.AddRoute(target.Route{
Src: row.Source,
Dst: row.Destination,
Flags: row.Flags.NormaliseRouteFlags(),
})
}
// sql or something?
redirectsRows, err := m.db.GetActiveRedirects(context.Background())
if err != nil {
return err
}
for _, row := range redirectsRows {
router.AddRedirect(target.Redirect{
Src: row.Source,
Dst: row.Destination,
Flags: row.Flags.NormaliseRedirectFlags(),
Code: row.Code,
Src: src,
Dst: dst,
Flags: flags.NormaliseRouteFlags(),
})
}
// check for errors
return nil
}
func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) {
if len(hosts) < 1 {
return []target.RouteWithActive{}, nil
if err := rows.Err(); err != nil {
return err
}
// sql or something?
rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`)
if err != nil {
return err
}
defer rows.Close()
// loop through rows and scan the options
for rows.Next() {
var (
src, dst string
flags target.Flags
code int
)
err := rows.Scan(&src, &dst, &flags, &code)
if err != nil {
return err
}
router.AddRedirect(target.Redirect{
Src: src,
Dst: dst,
Flags: flags.NormaliseRedirectFlags(),
Code: code,
})
}
// check for errors
return rows.Err()
}
func (m *Manager) GetAllRoutes() ([]target.RouteWithActive, error) {
s := make([]target.RouteWithActive, 0)
rows, err := m.db.GetAllRoutes(context.Background())
query, err := m.db.Query(`SELECT source, destination, flags, active FROM routes`)
if err != nil {
return nil, err
}
for _, row := range rows {
a := target.RouteWithActive{
Route: target.Route{
Src: row.Source,
Dst: row.Destination,
Desc: row.Description,
Flags: row.Flags,
},
Active: row.Active,
}
for _, i := range hosts {
// if this is never true then the domain was mistakenly grabbed from the database
if a.OnDomain(i) {
s = append(s, a)
break
}
for query.Next() {
var a target.RouteWithActive
if query.Scan(&a.Src, &a.Dst, &a.Flags, &a.Active) != nil {
return nil, err
}
s = append(s, a)
}
return s, nil
}
func (m *Manager) InsertRoute(route target.RouteWithActive) error {
return m.db.AddRoute(context.Background(), database.AddRouteParams{
Source: route.Src,
Destination: route.Dst,
Description: route.Desc,
Flags: route.Flags,
Active: route.Active,
})
func (m *Manager) InsertRoute(route target.Route) error {
_, err := m.db.Exec(`INSERT INTO routes (source, destination, flags) VALUES (?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, active = 1`, route.Src, route.Dst, route.Flags)
return err
}
func (m *Manager) DeleteRoute(source string) error {
return m.db.RemoveRoute(context.Background(), source)
_, err := m.db.Exec(`UPDATE routes SET active = 0 WHERE source = ?`, source)
return err
}
func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, error) {
if len(hosts) < 1 {
return []target.RedirectWithActive{}, nil
}
func (m *Manager) GetAllRedirects() ([]target.RedirectWithActive, error) {
s := make([]target.RedirectWithActive, 0)
rows, err := m.db.GetAllRedirects(context.Background())
query, err := m.db.Query(`SELECT source, destination, flags, code, active FROM redirects`)
if err != nil {
return nil, err
}
for _, row := range rows {
a := target.RedirectWithActive{
Redirect: target.Redirect{
Src: row.Source,
Dst: row.Destination,
Desc: row.Description,
Flags: row.Flags,
Code: row.Code,
},
Active: row.Active,
}
for _, i := range hosts {
// if this is never true then the domain was mistakenly grabbed from the database
if a.OnDomain(i) {
s = append(s, a)
break
}
for query.Next() {
var a target.RedirectWithActive
if query.Scan(&a.Src, &a.Dst, &a.Flags, &a.Code, &a.Active) != nil {
return nil, err
}
s = append(s, a)
}
return s, nil
}
func (m *Manager) InsertRedirect(redirect target.RedirectWithActive) error {
return m.db.AddRedirect(context.Background(), database.AddRedirectParams{
Source: redirect.Src,
Destination: redirect.Dst,
Description: redirect.Desc,
Flags: redirect.Flags,
Code: redirect.Code,
Active: redirect.Active,
})
func (m *Manager) InsertRedirect(redirect target.Redirect) error {
_, err := m.db.Exec(`INSERT INTO redirects (source, destination, flags, code) VALUES (?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, flags = excluded.flags, code = excluded.code, active = 1`, redirect.Src, redirect.Dst, redirect.Flags, redirect.Code)
return err
}
func (m *Manager) DeleteRedirect(source string) error {
return m.db.RemoveRedirect(context.Background(), source)
}
// GenerateHostSearch this should help improve performance
// TODO(Melon) discover how to implement this correctly
func GenerateHostSearch(hosts []string) (string, []string) {
var searchString strings.Builder
searchString.WriteString("WHERE ")
hostArgs := make([]string, len(hosts)*2)
for i := range hosts {
if i != 0 {
searchString.WriteString(" OR ")
}
// these like checks are not perfect but do reduce load on the database
searchString.WriteString("source LIKE '%' + ? + '/%'")
searchString.WriteString(" OR source LIKE '%' + ?")
// loads the hostname into even and odd args
hostArgs[i*2] = hosts[i]
hostArgs[i*2+1] = hosts[i]
}
return searchString.String(), hostArgs
_, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source)
return err
}

View File

@ -1,12 +1,9 @@
package router
import (
"context"
"github.com/1f349/violet"
"github.com/1f349/violet/database"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
"database/sql"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"net/http"
@ -24,11 +21,11 @@ func (f *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
func TestNewManager(t *testing.T) {
db, err := violet.InitDB("file:TestNewManager?mode=memory&cache=shared")
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
ft := &fakeTransport{}
ht := proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{})
ht := proxy.NewHybridTransportWithCalls(ft, ft)
m := NewManager(db, ht)
assert.NoError(t, m.internalCompile(m.r))
@ -41,13 +38,7 @@ func TestNewManager(t *testing.T) {
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Nil(t, ft.req)
err = db.AddRoute(context.Background(), database.AddRouteParams{
Source: "*.example.com",
Destination: "127.0.0.1:8080",
Description: "",
Flags: target.FlagAbs | target.FlagForwardHost | target.FlagForwardAddr,
Active: true,
})
_, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr)
assert.NoError(t, err)
assert.NoError(t, m.internalCompile(m.r))
@ -58,71 +49,3 @@ func TestNewManager(t *testing.T) {
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.NotNil(t, ft.req)
}
func TestManager_GetAllRoutes(t *testing.T) {
db, err := violet.InitDB("file:TestManager_GetAllRoutes?mode=memory&cache=shared")
assert.NoError(t, err)
m := NewManager(db, nil)
a := []error{
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.com"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.com"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.com/hello"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.com/hello"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.org"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.org"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.org/hello"}, Active: true}),
m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "test.example.org/hello"}, Active: true}),
}
for _, i := range a {
if i != nil {
t.Fatal(i)
}
}
routes, err := m.GetAllRoutes([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, []target.RouteWithActive{
{Route: target.Route{Src: "example.com"}, Active: true},
{Route: target.Route{Src: "test.example.com"}, Active: true},
{Route: target.Route{Src: "example.com/hello"}, Active: true},
{Route: target.Route{Src: "test.example.com/hello"}, Active: true},
}, routes)
}
func TestManager_GetAllRedirects(t *testing.T) {
db, err := violet.InitDB("file:TestManager_GetAllRedirects?mode=memory&cache=shared")
assert.NoError(t, err)
m := NewManager(db, nil)
a := []error{
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.com"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.com"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.com/hello"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.com/hello"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.org"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.org"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.org/hello"}, Active: true}),
m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "test.example.org/hello"}, Active: true}),
}
for _, i := range a {
if i != nil {
t.Fatal(i)
}
}
redirects, err := m.GetAllRedirects([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, []target.RedirectWithActive{
{Redirect: target.Redirect{Src: "example.com"}, Active: true},
{Redirect: target.Redirect{Src: "test.example.com"}, Active: true},
{Redirect: target.Redirect{Src: "example.com/hello"}, Active: true},
{Redirect: target.Redirect{Src: "test.example.com/hello"}, Active: true},
}, redirects)
}
func TestGenerateHostSearch(t *testing.T) {
query, args := GenerateHostSearch([]string{"example.com", "example.org"})
assert.Equal(t, "WHERE source LIKE '%' + ? + '/%' OR source LIKE '%' + ? OR source LIKE '%' + ? + '/%' OR source LIKE '%' + ?", query)
assert.Equal(t, []string{"example.com", "example.com", "example.org", "example.org"}, args)
}

View File

@ -2,10 +2,10 @@ package router
import (
"fmt"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/target"
"github.com/1f349/violet/utils"
"github.com/mrmelon54/trie"
"github.com/MrMelon54/trie"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"net/http"
"strings"
)
@ -90,29 +90,29 @@ func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (r *Router) serveRouteHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
h := r.route[host]
return getServeData(rw, req, h)
}
func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
h := r.redirect[host]
return getServeData(rw, req, h)
}
type serveDataInterface interface {
HasFlag(flag target.Flags) bool
ServeHTTP(rw http.ResponseWriter, req *http.Request)
}
func getServeData[T serveDataInterface](rw http.ResponseWriter, req *http.Request, h *trie.Trie[T]) bool {
if h == nil {
return false
}
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
if h != nil {
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
}
}
}
return false
}
func (r *Router) serveRedirectHTTP(rw http.ResponseWriter, req *http.Request, host string) bool {
h := r.redirect[host]
if h != nil {
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
for i := len(pairs) - 1; i >= 0; i-- {
if pairs[i].Value.Flags.HasFlag(target.FlagPre) || pairs[i].Key == req.URL.Path {
req.URL.Path = strings.TrimPrefix(req.URL.Path, pairs[i].Key)
pairs[i].Value.ServeHTTP(rw, req)
return true
}
}
}
return false

View File

@ -1,12 +1,8 @@
package router
import (
"fmt"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/target"
"github.com/mrmelon54/trie"
"github.com/stretchr/testify/assert"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/target"
"net/http"
"net/http/httptest"
"net/url"
@ -184,7 +180,7 @@ func TestRouter_AddRoute(t *testing.T) {
transInsecure := &fakeTransport{}
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{}))
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
dst := i.dst
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("example.com", i.path)
@ -198,7 +194,7 @@ func TestRouter_AddRoute(t *testing.T) {
if v == "" {
if transSecure.req != nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["example.com"].String())
t.Log(r.redirect["example.com"].String())
t.Fatalf("%s => %s\n", k, v)
}
} else {
@ -270,32 +266,32 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
transInsecure := &fakeTransport{}
for _, i := range routeTests {
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure, &websocket.Server{}))
r := New(proxy.NewHybridTransportWithCalls(transSecure, transInsecure))
dst := i.dst
dst.Dst = path.Join("127.0.0.1:8080", dst.Dst)
dst.Src = path.Join("*.example.com", i.path)
dst.Src = path.Join("example.com", i.path)
t.Logf("Running tests for %#v\n", dst)
r.AddRoute(dst)
for k, v := range i.tests {
u1 := &url.URL{Scheme: "https", Host: "test.example.com", Path: k}
u1 := &url.URL{Scheme: "https", Host: "example.com", Path: k}
req, _ := http.NewRequest(http.MethodGet, u1.String(), nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if v == "" {
if transSecure.req != nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["*.example.com"].String())
t.Log(r.redirect["example.com"].String())
t.Fatalf("%s => %s\n", k, v)
}
} else {
if transSecure.req == nil {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["*.example.com"].String())
t.Log(r.route["example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", k, v, k, "")
}
if v != transSecure.req.URL.Path {
t.Logf("Test URL: %#v\n", req.URL)
t.Log(r.route["*.example.com"].String())
t.Log(r.route["example.com"].String())
t.Fatalf("\nexpected %s => %s\n got %s => %s\n", k, v, k, transSecure.req.URL.Path)
}
transSecure.req = nil
@ -303,24 +299,3 @@ func TestRouter_AddWildcardRoute(t *testing.T) {
}
}
}
type fakeRoundTripper struct{}
func (f *fakeRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusNotFound)
return rec.Result(), nil
}
func TestGetServeData_Route(t *testing.T) {
hyb := proxy.NewHybridTransportWithCalls(&fakeRoundTripper{}, &fakeRoundTripper{}, nil)
req, err := http.NewRequest(http.MethodGet, "https://example.com/hello/world/this/is/a/test", nil)
assert.NoError(t, err)
h := trie.BuildFromMap(map[string]target.Route{
"/hello/world": {Flags: target.FlagPre, Proxy: hyb},
})
rec := httptest.NewRecorder()
pairs := h.GetAllKeyValues([]byte(req.URL.Path))
fmt.Printf("%#v\n", pairs)
assert.True(t, getServeData(rec, req, h))
}

View File

@ -2,13 +2,11 @@ package api
import (
"encoding/json"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/claims"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"net/http"
"time"
)
@ -17,16 +15,9 @@ import (
// endpoints for the software
//
// `/compile` - reloads all domains, routes and redirects
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable, registry *prometheus.Registry) *http.Server {
func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable) *http.Server {
r := httprouter.New()
r.GET("/", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
http.Error(rw, "Violet API Endpoint", http.StatusOK)
})
r.GET("/metrics", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
promhttp.HandlerFor(registry, promhttp.HandlerOpts{}).ServeHTTP(rw, req)
})
// Endpoint for compile action
r.POST("/compile", checkAuthWithPerm(conf.Signer, "violet:compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, b AuthClaims) {
// Trigger the compile action
@ -48,6 +39,7 @@ func NewApiServer(conf *conf.Conf, compileTarget utils.MultiCompilable, registry
// Create and run http server
return &http.Server{
Addr: conf.ApiListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
@ -70,7 +62,6 @@ func domainManage(verify mjwt.Verifier, domains utils.DomainProvider) httprouter
// add domain with active state
domains.Put(params.ByName("domain"), req.Method == http.MethodPut)
domains.Compile()
rw.WriteHeader(http.StatusAccepted)
})
}
@ -90,21 +81,11 @@ func acmeChallengeManage(verify mjwt.Verifier, domains utils.DomainProvider, acm
})
}
// getDomainOwnershipClaims returns the domains marked as owned from PermStorage,
// they match `domain:owns=<fqdn>` where fqdn will be returned
func getDomainOwnershipClaims(perms *claims.PermStorage) []string {
a := perms.Search("domain:owns=*")
for i := range a {
a[i] = a[i][len("domain:owns="):]
}
return a
}
// validateDomainOwnershipClaims validates if the claims contain the
// `domain:owns=<fqdn>` field with the matching top level domain
// `owns=<fqdn>` field with the matching top level domain
func validateDomainOwnershipClaims(a string, perms *claims.PermStorage) bool {
if fqdn, ok := utils.GetTopFqdn(a); ok {
if perms.Has("domain:owns=" + fqdn) {
if perms.Has("owns=" + fqdn) {
return true
}
}

View File

@ -1,9 +1,9 @@
package api
import (
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/1f349/violet/utils/fake"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/MrMelon54/violet/utils/fake"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
@ -17,7 +17,7 @@ func TestNewApiServer_Compile(t *testing.T) {
Signer: fake.SnakeOilProv,
}
f := &fake.Compilable{}
srv := NewApiServer(apiConf, utils.MultiCompilable{f}, nil)
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
assert.NoError(t, err)
@ -43,7 +43,7 @@ func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{}, nil)
srv := NewApiServer(apiConf, utils.MultiCompilable{})
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
// Valid domain
@ -87,7 +87,7 @@ func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{}, nil)
srv := NewApiServer(apiConf, utils.MultiCompilable{})
acmeKey := fake.GenSnakeOilKey("violet:acme-challenge")
// Valid domain

View File

@ -1,9 +1,9 @@
package api
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
)

View File

@ -1,7 +1,7 @@
package api
import (
"github.com/1f349/violet/target"
"github.com/MrMelon54/violet/target"
)
type sourceJson struct {
@ -10,11 +10,11 @@ type sourceJson struct {
func (s sourceJson) GetSource() string { return s.Src }
type routeSource target.RouteWithActive
type routeSource target.Route
func (r routeSource) GetSource() string { return r.Src }
type redirectSource target.RedirectWithActive
type redirectSource target.Redirect
func (r redirectSource) GetSource() string { return r.Src }

View File

@ -2,12 +2,12 @@ package api
import (
"encoding/json"
"github.com/1f349/mjwt"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/router"
"github.com/1f349/violet/target"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/target"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"log"
"net/http"
"strings"
)
@ -15,11 +15,8 @@ import (
func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router.Manager) {
// Endpoint for routes
r.GET("/route", checkAuthWithPerm(verify, "violet:route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
domains := getDomainOwnershipClaims(b.Claims.Perms)
routes, err := manager.GetAllRoutes(domains)
routes, err := manager.GetAllRoutes()
if err != nil {
logger.Logger.Infof("Failed to get routes from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to get routes from database")
return
}
@ -27,9 +24,9 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
_ = json.NewEncoder(rw).Encode(routes)
}))
r.POST("/route", parseJsonAndCheckOwnership[routeSource](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t routeSource) {
err := manager.InsertRoute(target.RouteWithActive(t))
err := manager.InsertRoute(target.Route(t))
if err != nil {
logger.Logger.Infof("Failed to insert route into database: %s\n", err)
log.Printf("[Violet] Failed to insert route into database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to insert route into database")
return
}
@ -38,7 +35,7 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
r.DELETE("/route", parseJsonAndCheckOwnership[sourceJson](verify, "route", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
err := manager.DeleteRoute(t.Src)
if err != nil {
logger.Logger.Infof("Failed to delete route from database: %s\n", err)
log.Printf("[Violet] Failed to delete route from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to delete route from database")
return
}
@ -47,11 +44,8 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
// Endpoint for redirects
r.GET("/redirect", checkAuthWithPerm(verify, "violet:redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
domains := getDomainOwnershipClaims(b.Claims.Perms)
redirects, err := manager.GetAllRedirects(domains)
redirects, err := manager.GetAllRedirects()
if err != nil {
logger.Logger.Infof("Failed to get redirects from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to get redirects from database")
return
}
@ -59,9 +53,9 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
_ = json.NewEncoder(rw).Encode(redirects)
}))
r.POST("/redirect", parseJsonAndCheckOwnership[redirectSource](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t redirectSource) {
err := manager.InsertRedirect(target.RedirectWithActive(t))
err := manager.InsertRedirect(target.Redirect(t))
if err != nil {
logger.Logger.Infof("Failed to insert redirect into database: %s\n", err)
log.Printf("[Violet] Failed to insert redirect into database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to insert redirect into database")
return
}
@ -70,7 +64,7 @@ func SetupTargetApis(r *httprouter.Router, verify mjwt.Verifier, manager *router
r.DELETE("/redirect", parseJsonAndCheckOwnership[sourceJson](verify, "redirect", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, t sourceJson) {
err := manager.DeleteRedirect(t.Src)
if err != nil {
logger.Logger.Infof("Failed to delete redirect from database: %s\n", err)
log.Printf("[Violet] Failed to delete redirect from database: %s\n", err)
apiError(rw, http.StatusInternalServerError, "Failed to delete redirect from database")
return
}

View File

@ -1,23 +1,26 @@
package conf
import (
"github.com/1f349/mjwt"
"github.com/1f349/violet/database"
errorPages "github.com/1f349/violet/error-pages"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/router"
"github.com/1f349/violet/utils"
"database/sql"
"github.com/MrMelon54/mjwt"
errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/utils"
)
// Conf stores the shared configuration for the API, HTTP and HTTPS servers.
type Conf struct {
RateLimit uint64 // rate limit per minute
DB *database.Queries
Domains utils.DomainProvider
Acme utils.AcmeChallengeProvider
Certs utils.CertProvider
Favicons *favicons.Favicons
Signer mjwt.Verifier
ErrorPages *errorPages.ErrorPages
Router *router.Manager
ApiListen string // api server listen address
HttpListen string // http server listen address
HttpsListen string // https server listen address
RateLimit uint64 // rate limit per minute
DB *sql.DB
Domains utils.DomainProvider
Acme utils.AcmeChallengeProvider
Certs utils.CertProvider
Favicons *favicons.Favicons
Signer mjwt.Verifier
ErrorPages *errorPages.ErrorPages
Router *router.Manager
}

View File

@ -2,11 +2,9 @@ package servers
import (
"fmt"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/servers/metrics"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus"
"net/http"
"net/url"
"time"
@ -17,9 +15,13 @@ import (
//
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Let's Encrypt HTTP verification.
func NewHttpServer(httpsPort uint16, conf *conf.Conf, registry *prometheus.Registry) *http.Server {
func NewHttpServer(conf *conf.Conf) *http.Server {
r := httprouter.New()
var secureExtend string
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
if !ok {
httpsPort = 443
}
if httpsPort != 443 {
secureExtend = fmt.Sprintf(":%d", httpsPort)
}
@ -59,16 +61,10 @@ func NewHttpServer(httpsPort uint16, conf *conf.Conf, registry *prometheus.Regis
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
})
metricsMiddleware := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r.ServeHTTP(rw, req)
})
if registry != nil {
metricsMiddleware = metrics.New(registry, nil).WrapHandler("violet-http-insecure", r)
}
// Create and run http server
return &http.Server{
Handler: metricsMiddleware,
Addr: conf.HttpListen,
Handler: r,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
WriteTimeout: time.Minute,

View File

@ -2,9 +2,9 @@ package servers
import (
"bytes"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils"
"github.com/1f349/violet/utils/fake"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/MrMelon54/violet/utils/fake"
"github.com/stretchr/testify/assert"
"io"
"net/http"
@ -18,7 +18,7 @@ func TestNewHttpServer_AcmeChallenge(t *testing.T) {
Acme: utils.NewAcmeChallenge(),
Signer: fake.SnakeOilProv,
}
srv := NewHttpServer(443, httpConf, nil)
srv := NewHttpServer(httpConf)
httpConf.Acme.Put("example.com", "456", "456def")
req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil)

View File

@ -3,78 +3,47 @@ package servers
import (
"crypto/tls"
"fmt"
"github.com/1f349/violet/favicons"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/servers/metrics"
"github.com/1f349/violet/utils"
"github.com/prometheus/client_golang/prometheus"
"github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
"log"
"net"
"net/http"
"path"
"runtime"
"time"
)
// NewHttpsServer creates and runs a http server containing the public https
// endpoints for the reverse proxy.
func NewHttpsServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server {
r := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
logger.Logger.Debug("Request", "method", req.Method, "url", req.URL, "remote", req.RemoteAddr, "host", req.Host, "length", req.ContentLength, "goroutine", runtime.NumGoroutine())
conf.Router.ServeHTTP(rw, req)
})
favMiddleware := setupFaviconMiddleware(conf.Favicons, r)
metricsMeta := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r.ServeHTTP(rw, req)
})
if registry != nil {
metricsMiddleware := metrics.New(registry, nil).WrapHandler("violet-https", favMiddleware)
metricsMeta = func(rw http.ResponseWriter, req *http.Request) {
metricsMiddleware.ServeHTTP(rw, metrics.AddHostCtx(req))
}
}
rateLimiter := setupRateLimiter(conf.RateLimit, metricsMeta)
hsts := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
rateLimiter.ServeHTTP(rw, req)
})
func NewHttpsServer(conf *conf.Conf) *http.Server {
return &http.Server{
Handler: hsts,
TLSConfig: &tls.Config{
// Suggested by https://ssl-config.mozilla.org/#server=go&version=1.21.5&config=intermediate
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains
if !conf.Domains.IsValid(info.ServerName) {
return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName)
}
Addr: conf.HttpsListen,
Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains
if !conf.Domains.IsValid(info.ServerName) {
return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName)
}
// find a certificate
cert := conf.Certs.GetCertForDomain(info.ServerName)
if cert == nil {
return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName)
}
// find a certificate
cert := conf.Certs.GetCertForDomain(info.ServerName)
if cert == nil {
return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName)
}
// time to return
return cert, nil
},
},
// time to return
return cert, nil
}},
ReadTimeout: 150 * time.Second,
ReadHeaderTimeout: 150 * time.Second,
WriteTimeout: 150 * time.Second,
IdleTimeout: 150 * time.Second,
MaxHeaderBytes: 4096000,
ConnState: func(conn net.Conn, state http.ConnState) {
fmt.Printf("[HTTPS] %s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String())
},
}
}
@ -87,24 +56,19 @@ func setupRateLimiter(rateLimit uint64, next http.Handler) http.Handler {
Interval: time.Minute,
})
if err != nil {
logger.Logger.Fatal("Failed to initialize memory store", "err", err)
log.Fatalln(err)
}
// create a middleware using ips as the key for rate limits
middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc())
if err != nil {
logger.Logger.Fatal("Failed to initialize httplimit middleware", "err", err)
log.Fatalln(err)
}
return middleware.Handle(next)
}
func setupFaviconMiddleware(fav *favicons.Favicons, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("X-Violet-Loop-Detect") == "1" {
rw.WriteHeader(http.StatusLoopDetected)
_, _ = rw.Write([]byte("Detected a routing loop\n"))
return
}
if req.Header.Get("X-Violet-Raw-Favicon") != "1" {
switch req.URL.Path {
case "/favicon.svg", "/favicon.png", "/favicon.ico":

View File

@ -1,13 +1,12 @@
package servers
import (
"github.com/1f349/violet"
"github.com/1f349/violet/certs"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/1f349/violet/router"
"github.com/1f349/violet/servers/conf"
"github.com/1f349/violet/utils/fake"
"database/sql"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/servers/conf"
"github.com/MrMelon54/violet/utils/fake"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"net/http"
@ -25,7 +24,7 @@ func (f *fakeTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
}
func TestNewHttpsServer_RateLimit(t *testing.T) {
db, err := violet.InitDB("file:TestNewHttpsServer_RateLimit?mode=memory&cache=shared")
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
ft := &fakeTransport{}
@ -34,9 +33,9 @@ func TestNewHttpsServer_RateLimit(t *testing.T) {
Domains: &fake.Domains{},
Certs: certs.New(nil, nil, true),
Signer: fake.SnakeOilProv,
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft, &websocket.Server{})),
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
}
srv := NewHttpsServer(httpsConf, nil)
srv := NewHttpsServer(httpsConf)
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
req.RemoteAddr = "127.0.0.1:1447"

View File

@ -1,118 +0,0 @@
package metrics
import (
"context"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"net/http"
)
// Copyright 2022 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package metrics is adapted from
// https://github.com/bwplotka/correlator/tree/main/examples/observability/ping/pkg/httpinstrumentation
// https://github.com/prometheus/client_golang/blob/main/examples/middleware/httpmiddleware/httpmiddleware.go
type Middleware interface {
// WrapHandler wraps the given HTTP handler for instrumentation.
WrapHandler(handlerName string, handler http.Handler) http.HandlerFunc
}
type middleware struct {
buckets []float64
registry prometheus.Registerer
}
// WrapHandler wraps the given HTTP handler for instrumentation:
// It registers four metric collectors (if not already done) and reports HTTP
// metrics to the (newly or already) registered collectors.
// Each has a constant label named "handler" with the provided handlerName as
// value.
func (m *middleware) WrapHandler(handlerName string, handler http.Handler) http.HandlerFunc {
reg := prometheus.WrapRegistererWith(prometheus.Labels{"handler": handlerName}, m.registry)
requestsTotal := promauto.With(reg).NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Tracks the number of HTTP requests.",
}, []string{"method", "code", "host"},
)
requestDuration := promauto.With(reg).NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "Tracks the latencies for HTTP requests.",
Buckets: m.buckets,
},
[]string{"method", "code", "host"},
)
requestSize := promauto.With(reg).NewSummaryVec(
prometheus.SummaryOpts{
Name: "http_request_size_bytes",
Help: "Tracks the size of HTTP requests.",
},
[]string{"method", "code", "host"},
)
responseSize := promauto.With(reg).NewSummaryVec(
prometheus.SummaryOpts{
Name: "http_response_size_bytes",
Help: "Tracks the size of HTTP responses.",
},
[]string{"method", "code", "host"},
)
hostCtxGetter := promhttp.WithLabelFromCtx("host", func(ctx context.Context) string {
s, _ := ctx.Value(hostCtxKey(0)).(string)
return s
})
// Wraps the provided http.Handler to observe the request result with the provided metrics.
base := promhttp.InstrumentHandlerCounter(
requestsTotal,
promhttp.InstrumentHandlerDuration(
requestDuration,
promhttp.InstrumentHandlerRequestSize(
requestSize,
promhttp.InstrumentHandlerResponseSize(
responseSize,
handler,
hostCtxGetter,
),
hostCtxGetter,
),
hostCtxGetter,
),
hostCtxGetter,
)
return base.ServeHTTP
}
// New returns a Middleware interface.
func New(registry prometheus.Registerer, buckets []float64) Middleware {
if buckets == nil {
buckets = prometheus.ExponentialBuckets(0.1, 1.5, 5)
}
return &middleware{
buckets: buckets,
registry: registry,
}
}
type hostCtxKey uint8
func AddHostCtx(req *http.Request) *http.Request {
return req.WithContext(context.WithValue(req.Context(), hostCtxKey(0), req.Host))
}

View File

@ -1,15 +0,0 @@
version: "2"
sql:
- engine: sqlite
queries: database/queries
schema: database/migrations
gen:
go:
package: "database"
out: "database"
emit_json_tags: true
overrides:
- column: "routes.flags"
go_type: "github.com/1f349/violet/target.Flags"
- column: "redirects.flags"
go_type: "github.com/1f349/violet/target.Flags"

View File

@ -10,11 +10,10 @@ const (
FlagForwardHost
FlagForwardAddr
FlagIgnoreCert
FlagWebsocket
)
var (
routeFlagMask = FlagPre | FlagAbs | FlagCors | FlagSecureMode | FlagForwardHost | FlagForwardAddr | FlagIgnoreCert | FlagWebsocket
routeFlagMask = FlagPre | FlagAbs | FlagCors | FlagSecureMode | FlagForwardHost | FlagForwardAddr | FlagIgnoreCert
redirectFlagMask = FlagPre | FlagAbs
)

View File

@ -2,7 +2,7 @@ package target
import (
"fmt"
"github.com/1f349/violet/utils"
"github.com/MrMelon54/violet/utils"
"net/http"
"net/url"
"path"
@ -14,9 +14,8 @@ import (
type Redirect struct {
Src string `json:"src"` // request source
Dst string `json:"dst"` // redirect destination
Desc string `json:"desc"` // description for admin panel use
Flags Flags `json:"flags"` // extra flags
Code int64 `json:"code"` // status code used to redirect
Code int `json:"code"` // status code used to redirect
}
type RedirectWithActive struct {
@ -24,18 +23,7 @@ type RedirectWithActive struct {
Active bool `json:"active"`
}
func (r Redirect) OnDomain(domain string) bool {
// if there is no / then the first part is still the domain
domainPart, _, _ := strings.Cut(r.Src, "/")
if domainPart == domain {
return true
}
// domainPart could start with a subdomain
return strings.HasSuffix(domainPart, "."+domain)
}
func (r Redirect) HasFlag(flag Flags) bool {
func (r Route) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}
@ -72,13 +60,8 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
Path: p,
}
// close the incoming body after use
if req.Body != nil {
defer req.Body.Close()
}
// use fast redirect for speed
utils.FastRedirect(rw, req, u.String(), int(code))
utils.FastRedirect(rw, req, u.String(), code)
}
// String outputs a debug string for the redirect.

View File

@ -7,22 +7,6 @@ import (
"testing"
)
func TestRedirect_OnDomain(t *testing.T) {
assert.True(t, Route{Src: "example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "example.com/hello"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com/hello"}.OnDomain("example.com"))
assert.False(t, Route{Src: "example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "example.com/hello"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com/hello"}.OnDomain("example.org"))
}
func TestRedirect_HasFlag(t *testing.T) {
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
}
func TestRedirect_ServeHTTP(t *testing.T) {
a := []struct {
Redirect
@ -35,7 +19,7 @@ func TestRedirect_ServeHTTP(t *testing.T) {
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil)
i.ServeHTTP(res, req)
assert.Equal(t, i.Code, int64(res.Code))
assert.Equal(t, i.Code, res.Code)
assert.Equal(t, i.target, res.Header().Get("Location"))
}
}

View File

@ -1,14 +1,14 @@
package target
import (
"context"
"fmt"
"github.com/1f349/violet/logger"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/utils"
websocket2 "github.com/gorilla/websocket"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/utils"
"github.com/rs/cors"
"golang.org/x/net/http/httpguts"
"io"
"log"
"net"
"net/http"
"net/textproto"
@ -17,13 +17,10 @@ import (
"strings"
)
var Logger = logger.Logger.WithPrefix("Violet Serve Route")
// serveApiCors outputs the cors headers to make APIs work.
var serveApiCors = cors.New(cors.Options{
// allow all origins for api requests
AllowOriginFunc: func(origin string) bool { return true },
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowedOrigins: []string{"*"}, // allow all origins for api requests
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowedMethods: []string{
http.MethodGet,
http.MethodHead,
@ -41,7 +38,6 @@ var serveApiCors = cors.New(cors.Options{
type Route struct {
Src string `json:"src"` // request source
Dst string `json:"dst"` // proxy destination
Desc string `json:"desc"` // description for admin panel use
Flags Flags `json:"flags"` // extra flags
Headers http.Header `json:"-"` // extra headers
Proxy *proxy.HybridTransport `json:"-"` // reverse proxy handler
@ -52,21 +48,6 @@ type RouteWithActive struct {
Active bool `json:"active"`
}
func (r Route) OnDomain(domain string) bool {
// if there is no / then the first part is still the domain
domainPart, _, _ := strings.Cut(r.Src, "/")
if domainPart == domain {
return true
}
// domainPart could start with a subdomain
return strings.HasSuffix(domainPart, "."+domain)
}
func (r Route) HasFlag(flag Flags) bool {
return r.Flags&flag != 0
}
// UpdateHeaders takes an existing set of headers and overwrites them with the
// extra headers.
func (r Route) UpdateHeaders(header http.Header) {
@ -129,7 +110,8 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// create the internal request
req2, err := http.NewRequest(req.Method, u.String(), req.Body)
if err != nil {
utils.RespondVioletError(rw, http.StatusBadGateway, "Invalid request for proxy")
log.Printf("[ServeRoute::ServeHTTP()] Error generating new request: %s\n", err)
utils.RespondVioletError(rw, http.StatusBadGateway, "error generating new request")
return
}
@ -156,20 +138,12 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
if r.HasFlag(FlagForwardHost) {
req2.Host = req.Host
}
if r.HasFlag(FlagForwardAddr) {
req2.Header.Add("X-Forwarded-For", req.RemoteAddr)
}
// adds extra request metadata
if r.internalReverseProxyMeta(rw, req, req2) {
return
}
// switch to websocket handler
// internally the http hijack method is called
if r.HasFlag(FlagWebsocket) && websocket2.IsWebSocketUpgrade(req2) {
r.Proxy.ConnectWebsocket(rw, req2)
return
}
req2.Header.Set("X-Violet-Loop-Detect", "1")
r.internalReverseProxyMeta(rw, req)
// serve request with reverse proxy
var resp *http.Response
@ -179,19 +153,8 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
resp, err = r.Proxy.SecureRoundTrip(req2)
}
if err != nil {
Logger.Warn("Error receiving internal round trip response", "route src", r.Src, "url", req2.URL.String(), "err", err)
utils.RespondVioletError(rw, http.StatusBadGateway, "Error receiving internal round trip response")
return
}
// make sure to close response body after use
if resp.Body != nil {
defer resp.Body.Close()
}
if resp.StatusCode == http.StatusLoopDetected {
Logger.Warn("Loop Detected", "method", req.Method, "url", req.URL, "url2", req2.URL.String())
utils.RespondVioletError(rw, http.StatusLoopDetected, "Error loop detected")
log.Printf("[ServeRoute::ServeHTTP()] Error receiving internal round trip response: %s\n", err)
utils.RespondVioletError(rw, http.StatusBadGateway, "error receiving internal round trip response")
return
}
@ -203,6 +166,14 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
if resp.Body != nil {
_, err := io.Copy(rw, resp.Body)
if err != nil {
// hijack and close upon error
if h, ok := rw.(http.Hijacker); ok {
hijack, _, err := h.Hijack()
if err != nil {
return
}
_ = hijack.Close()
}
return
}
}
@ -212,20 +183,21 @@ func (r Route) internalServeHTTP(rw http.ResponseWriter, req *http.Request) {
// due to the highly custom nature of this reverse proxy software we use a copy
// of the code instead of the full httputil implementation to prevent overhead
// from the more generic implementation
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req, req2 *http.Request) bool {
func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req *http.Request) {
outreq := req.Clone(context.Background())
if req.ContentLength == 0 {
req2.Body = nil // Issue 16036: nil Body for http.Transport retries
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if req2.Header == nil {
req2.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
reqUpType := upgradeType(req2.Header)
reqUpType := upgradeType(outreq.Header)
if !asciiIsPrint(reqUpType) {
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("Invalid protocol %s", reqUpType))
return true
utils.RespondVioletError(rw, http.StatusBadRequest, fmt.Sprintf("client tried to switch to invalid protocol %q", reqUpType))
return
}
removeHopByHopHeaders(req2.Header)
removeHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
@ -233,33 +205,29 @@ func (r Route) internalReverseProxyMeta(rw http.ResponseWriter, req, req2 *http.
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
req2.Header.Set("Te", "trailers")
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
req2.Header.Set("Connection", "Upgrade")
req2.Header.Set("Upgrade", reqUpType)
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
if r.HasFlag(FlagForwardAddr) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := req2.Header["X-Forwarded-For"]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
if !omit {
req2.Header.Set("X-Forwarded-For", clientIP)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header["X-Forwarded-For"]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
if !omit {
outreq.Header.Set("X-Forwarded-For", clientIP)
}
}
return false
}
// String outputs a debug string for the route.

View File

@ -2,11 +2,9 @@ package target
import (
"bytes"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/proxy/websocket"
"github.com/MrMelon54/violet/proxy"
"github.com/stretchr/testify/assert"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
@ -18,7 +16,7 @@ type proxyTester struct {
}
func (p *proxyTester) makeHybridTransport() *proxy.HybridTransport {
return proxy.NewHybridTransportWithCalls(p, p, &websocket.Server{})
return proxy.NewHybridTransportWithCalls(p, p)
}
func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
@ -27,17 +25,6 @@ func (p *proxyTester) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK}, nil
}
func TestRoute_OnDomain(t *testing.T) {
assert.True(t, Route{Src: "example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com"}.OnDomain("example.com"))
assert.True(t, Route{Src: "example.com/hello"}.OnDomain("example.com"))
assert.True(t, Route{Src: "test.example.com/hello"}.OnDomain("example.com"))
assert.False(t, Route{Src: "example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com"}.OnDomain("example.org"))
assert.False(t, Route{Src: "example.com/hello"}.OnDomain("example.org"))
assert.False(t, Route{Src: "test.example.com/hello"}.OnDomain("example.org"))
}
func TestRoute_HasFlag(t *testing.T) {
assert.True(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagPre))
assert.False(t, Route{Flags: FlagPre | FlagAbs}.HasFlag(FlagCors))
@ -53,7 +40,7 @@ func TestRoute_ServeHTTP(t *testing.T) {
{Route{Dst: "2.2.2.2/world", Flags: FlagAbs | FlagSecureMode}, "https://2.2.2.2/world"},
{Route{Dst: "api.example.com/world", Flags: FlagAbs | FlagSecureMode | FlagForwardHost}, "https://api.example.com/world"},
{Route{Dst: "api.example.org/world", Flags: FlagAbs | FlagSecureMode | FlagForwardAddr}, "https://api.example.org/world"},
{Route{Dst: "3.3.3.3/headers", Flags: FlagAbs, Headers: http.Header{"X-Other": []string{"test value"}, "X-Violet-Loop-Detect": []string{"1"}}}, "http://3.3.3.3/headers"},
{Route{Dst: "3.3.3.3/headers", Flags: FlagAbs, Headers: http.Header{"X-Other": []string{"test value"}}}, "http://3.3.3.3/headers"},
}
for _, i := range a {
pt := &proxyTester{}
@ -65,9 +52,7 @@ func TestRoute_ServeHTTP(t *testing.T) {
assert.True(t, pt.got)
assert.Equal(t, i.target, pt.req.URL.String())
if i.HasFlag(FlagForwardAddr) {
host, _, err := net.SplitHostPort(req.RemoteAddr)
assert.NoError(t, err)
assert.Equal(t, host, pt.req.Header.Get("X-Forwarded-For"))
assert.Equal(t, req.RemoteAddr, pt.req.Header.Get("X-Forwarded-For"))
}
if i.HasFlag(FlagForwardHost) {
assert.Equal(t, req.Host, pt.req.Host)
@ -90,7 +75,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) {
assert.Equal(t, http.MethodOptions, pt.req.Method)
assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String())
assert.Equal(t, "Origin", res.Header().Get("Vary"))
assert.Equal(t, "https://test.example.com", res.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Origin", res.Header().Get("Vary"))
}

View File

@ -1,7 +1,6 @@
package utils
import (
"golang.org/x/net/publicsuffix"
"strconv"
"strings"
)
@ -66,8 +65,23 @@ func GetParentDomain(domain string) (string, bool) {
//
// hello.world.example.com => example.com
func GetTopFqdn(domain string) (string, bool) {
out, err := publicsuffix.EffectiveTLDPlusOne(domain)
return out, err == nil
var countDot int
n := strings.LastIndexFunc(domain, func(r rune) bool {
// return true if this is the second '.'
// otherwise counts one and continues
if r == '.' {
if countDot == 1 {
return true
}
countDot++
}
return false
})
// if a valid index isn't found then return false
if n == -1 {
return "", false
}
return domain[n+1:], true
}
// SplitHostPath extracts the host/path from the input

View File

@ -52,11 +52,7 @@ func TestGetBaseDomain(t *testing.T) {
}
func TestGetTopFqdn(t *testing.T) {
domain, ok := GetTopFqdn("example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)
domain, ok = GetTopFqdn("www.example.com")
domain, ok := GetTopFqdn("www.example.com")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain)

View File

@ -1,6 +1,6 @@
package fake
import "github.com/1f349/violet/utils"
import "github.com/MrMelon54/violet/utils"
// Compilable implements utils.Compilable and stores if the Compile function
// is called.

View File

@ -1,6 +1,6 @@
package fake
import "github.com/1f349/violet/utils"
import "github.com/MrMelon54/violet/utils"
// Domains implements DomainProvider and makes sure `example.com` is valid
type Domains struct{}

View File

@ -3,9 +3,9 @@ package fake
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims"
"time"
)

View File

@ -1,35 +1,33 @@
package utils
import (
"errors"
"github.com/charmbracelet/log"
"net"
"log"
"net/http"
"strings"
)
// logHttpServerError is the internal function powering the logging in
// RunBackgroundHttp and RunBackgroundHttps.
func logHttpServerError(logger *log.Logger, err error) {
func logHttpServerError(prefix string, err error) {
if err != nil {
if errors.Is(err, http.ErrServerClosed) {
logger.Info("The http server shutdown successfully")
if err == http.ErrServerClosed {
log.Printf("[%s] The http server shutdown successfully\n", prefix)
} else {
logger.Info("Error trying to host the http server", "err", err.Error())
log.Printf("[%s] Error trying to host the http server: %s\n", prefix, err.Error())
}
}
}
// RunBackgroundHttp runs a http server and logs when the server closes or
// errors.
func RunBackgroundHttp(logger *log.Logger, s *http.Server, ln net.Listener) {
logHttpServerError(logger, s.Serve(ln))
func RunBackgroundHttp(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServe())
}
// RunBackgroundHttps runs a http server with TLS encryption and logs when the
// server closes or errors.
func RunBackgroundHttps(logger *log.Logger, s *http.Server, ln net.Listener) {
logHttpServerError(logger, s.ServeTLS(ln, "", ""))
func RunBackgroundHttps(prefix string, s *http.Server) {
logHttpServerError(prefix, s.ListenAndServeTLS("", ""))
}
// GetBearer returns the bearer from the Authorization header or an empty string

View File

@ -1,77 +0,0 @@
openapi: 3.0.3
info:
title: Violet
description: Violet
version: 1.0.0
contact:
name: Webmaster
email: webmaster@1f349.net
servers:
- url: 'https://api.1f349.net/v1/violet'
paths:
/compile:
post:
summary: Compile quick access data
tags:
- compile
responses:
'202':
description: Compile trigger sent
/domain/{domain}:
put:
summary: Add an allowed domain
tags:
- domain
parameters:
- name: domain
in: path
required: true
description: The domain to add
schema:
type: string
responses:
'202':
description: Domain added and compiled list reloaded
delete:
summary: Remove an allowed domain
tags:
- domain
parameters:
- name: domain
in: path
required: true
description: The domain to remove
schema:
type: string
responses:
'202':
description: Domain removed and compiled list reloaded
/acme-challenge/{domain}/{key}/{value}:
put:
summary: Add ACME challenge value
tags:
- acme-challenge
parameters:
- name: domain
in: path
required: true
description: The domain to add the challenge on
schema:
type: string
responses:
'202':
description: ACME challenge added
delete:
summary: Add ACME challenge value
tags:
- acme-challenge
parameters:
- name: domain
in: path
required: true
description: The domain to add the challenge on
schema:
type: string
responses:
'202':
description: ACME challenge added