From 69bce2d12de4a6c71bff01bdfac63cc08df5b7e8 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Fri, 8 Mar 2024 16:05:39 +0000 Subject: [PATCH] Initial support for sqlc and migrations --- cmd/violet/serve.go | 4 +- cmd/violet/setup.go | 4 +- database/db.go | 31 +++ database/domain.sql.go | 68 +++++ database/favicon.sql.go | 74 ++++++ .../migrations/20240308125121_init.down.sql | 4 + .../migrations/20240308125121_init.up.sql | 36 +++ database/models.go | 44 +++ database/queries/domain.sql | 16 ++ database/queries/favicon.sql | 8 + database/queries/routing.sql | 39 +++ database/routing.sql.go | 250 ++++++++++++++++++ domains/create-table-domains.sql | 6 - domains/domains.go | 39 +-- domains/domains_test.go | 12 +- favicons/create-table-favicons.sql | 7 - favicons/favicon-image.go | 8 +- favicons/favicons.go | 37 +-- favicons/favicons_test.go | 13 +- go.mod | 35 +-- go.sum | 78 +++--- initdb.go | 38 +++ router/create-tables.sql | 20 -- router/manager.go | 127 ++++----- router/manager_test.go | 26 +- servers/conf/conf.go | 4 +- servers/http.go | 7 +- servers/https.go | 14 +- servers/https_test.go | 4 +- sqlc.yaml | 15 ++ target/redirect.go | 4 +- target/redirect_test.go | 2 +- target/route_test.go | 2 +- 33 files changed, 830 insertions(+), 246 deletions(-) create mode 100644 database/db.go create mode 100644 database/domain.sql.go create mode 100644 database/favicon.sql.go create mode 100644 database/migrations/20240308125121_init.down.sql create mode 100644 database/migrations/20240308125121_init.up.sql create mode 100644 database/models.go create mode 100644 database/queries/domain.sql create mode 100644 database/queries/favicon.sql create mode 100644 database/queries/routing.sql create mode 100644 database/routing.sql.go delete mode 100644 domains/create-table-domains.sql delete mode 100644 favicons/create-table-favicons.sql create mode 100644 initdb.go delete mode 100644 router/create-tables.sql create mode 100644 sqlc.yaml diff --git a/cmd/violet/serve.go b/cmd/violet/serve.go index b3be1c8..f92f02d 100644 --- a/cmd/violet/serve.go +++ b/cmd/violet/serve.go @@ -2,10 +2,10 @@ package main import ( "context" - "database/sql" "encoding/json" "flag" "github.com/1f349/mjwt" + "github.com/1f349/violet" "github.com/1f349/violet/certs" "github.com/1f349/violet/domains" errorPages "github.com/1f349/violet/error-pages" @@ -107,7 +107,7 @@ func normalLoad(startUp startUpConfig, wd string) { } // open sqlite database - db, err := sql.Open("sqlite3", filepath.Join(wd, "violet.db.sqlite")) + db, err := violet.InitDB(filepath.Join(wd, "violet.db.sqlite")) if err != nil { log.Fatal("[Violet] Failed to open database") } diff --git a/cmd/violet/setup.go b/cmd/violet/setup.go index 3789e5a..9ecc558 100644 --- a/cmd/violet/setup.go +++ b/cmd/violet/setup.go @@ -2,10 +2,10 @@ package main import ( "context" - "database/sql" "encoding/json" "flag" "fmt" + "github.com/1f349/violet" "github.com/1f349/violet/domains" "github.com/1f349/violet/proxy" "github.com/1f349/violet/proxy/websocket" @@ -147,7 +147,7 @@ func (s *setupCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) } // open sqlite database - db, err := sql.Open("sqlite3", databaseFile) + db, err := violet.InitDB(databaseFile) if err != nil { log.Fatalf("[Violet] Failed to open database '%s'...", databaseFile) } diff --git a/database/db.go b/database/db.go new file mode 100644 index 0000000..61f5bf4 --- /dev/null +++ b/database/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package database + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/database/domain.sql.go b/database/domain.sql.go new file mode 100644 index 0000000..4cad183 --- /dev/null +++ b/database/domain.sql.go @@ -0,0 +1,68 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: domain.sql + +package database + +import ( + "context" +) + +const addDomain = `-- name: AddDomain :exec +INSERT OR +REPLACE +INTO domains (domain, active) +VALUES (?, ?) +` + +type AddDomainParams struct { + Domain string `json:"domain"` + Active bool `json:"active"` +} + +func (q *Queries) AddDomain(ctx context.Context, arg AddDomainParams) error { + _, err := q.db.ExecContext(ctx, addDomain, arg.Domain, arg.Active) + return err +} + +const deleteDomain = `-- name: DeleteDomain :exec +INSERT OR +REPLACE +INTO domains(domain, active) +VALUES (?, false) +` + +func (q *Queries) DeleteDomain(ctx context.Context, domain string) error { + _, err := q.db.ExecContext(ctx, deleteDomain, domain) + return err +} + +const getActiveDomains = `-- name: GetActiveDomains :many +SELECT domain +FROM domains +WHERE active = 1 +` + +func (q *Queries) GetActiveDomains(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getActiveDomains) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var domain string + if err := rows.Scan(&domain); err != nil { + return nil, err + } + items = append(items, domain) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/database/favicon.sql.go b/database/favicon.sql.go new file mode 100644 index 0000000..1513cce --- /dev/null +++ b/database/favicon.sql.go @@ -0,0 +1,74 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: favicon.sql + +package database + +import ( + "context" + "database/sql" +) + +const getFavicons = `-- name: GetFavicons :many +SELECT host, svg, png, ico +FROM favicons +` + +type GetFaviconsRow struct { + Host string `json:"host"` + Svg sql.NullString `json:"svg"` + Png sql.NullString `json:"png"` + Ico sql.NullString `json:"ico"` +} + +func (q *Queries) GetFavicons(ctx context.Context) ([]GetFaviconsRow, error) { + rows, err := q.db.QueryContext(ctx, getFavicons) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetFaviconsRow + for rows.Next() { + var i GetFaviconsRow + if err := rows.Scan( + &i.Host, + &i.Svg, + &i.Png, + &i.Ico, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateFaviconCache = `-- name: UpdateFaviconCache :exec +INSERT OR +REPLACE INTO favicons (host, svg, png, ico) +VALUES (?, ?, ?, ?) +` + +type UpdateFaviconCacheParams struct { + Host string `json:"host"` + Svg sql.NullString `json:"svg"` + Png sql.NullString `json:"png"` + Ico sql.NullString `json:"ico"` +} + +func (q *Queries) UpdateFaviconCache(ctx context.Context, arg UpdateFaviconCacheParams) error { + _, err := q.db.ExecContext(ctx, updateFaviconCache, + arg.Host, + arg.Svg, + arg.Png, + arg.Ico, + ) + return err +} diff --git a/database/migrations/20240308125121_init.down.sql b/database/migrations/20240308125121_init.down.sql new file mode 100644 index 0000000..faf92e0 --- /dev/null +++ b/database/migrations/20240308125121_init.down.sql @@ -0,0 +1,4 @@ +DROP TABLE domains; +DROP TABLE favicons; +DROP TABLE routes; +DROP TABLE redirects; diff --git a/database/migrations/20240308125121_init.up.sql b/database/migrations/20240308125121_init.up.sql new file mode 100644 index 0000000..e9469f4 --- /dev/null +++ b/database/migrations/20240308125121_init.up.sql @@ -0,0 +1,36 @@ +CREATE TABLE IF NOT EXISTS domains +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT UNIQUE NOT NULL, + active BOOLEAN NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS favicons +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + host VARCHAR NOT NULL, + svg VARCHAR, + png VARCHAR, + ico VARCHAR +); + +CREATE TABLE IF NOT EXISTS routes +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT UNIQUE NOT NULL, + destination TEXT NOT NULL, + description TEXT NOT NULL, + flags INTEGER NOT NULL DEFAULT 0, + active BOOLEAN NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS redirects +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT UNIQUE NOT NULL, + destination TEXT NOT NULL, + description TEXT NOT NULL, + flags INTEGER NOT NULL DEFAULT 0, + code INTEGER NOT NULL DEFAULT 0, + active BOOLEAN NOT NULL DEFAULT 1 +); diff --git a/database/models.go b/database/models.go new file mode 100644 index 0000000..836327b --- /dev/null +++ b/database/models.go @@ -0,0 +1,44 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package database + +import ( + "database/sql" + + "github.com/1f349/violet/target" +) + +type Domain struct { + ID int64 `json:"id"` + Domain string `json:"domain"` + Active bool `json:"active"` +} + +type Favicon struct { + ID int64 `json:"id"` + Host string `json:"host"` + Svg sql.NullString `json:"svg"` + Png sql.NullString `json:"png"` + Ico sql.NullString `json:"ico"` +} + +type Redirect struct { + ID int64 `json:"id"` + Source string `json:"source"` + Destination string `json:"destination"` + Description string `json:"description"` + Flags target.Flags `json:"flags"` + Code int64 `json:"code"` + Active bool `json:"active"` +} + +type Route struct { + ID int64 `json:"id"` + Source string `json:"source"` + Destination string `json:"destination"` + Description string `json:"description"` + Flags target.Flags `json:"flags"` + Active bool `json:"active"` +} diff --git a/database/queries/domain.sql b/database/queries/domain.sql new file mode 100644 index 0000000..7c4df8b --- /dev/null +++ b/database/queries/domain.sql @@ -0,0 +1,16 @@ +-- name: GetActiveDomains :many +SELECT domain +FROM domains +WHERE active = 1; + +-- name: AddDomain :exec +INSERT OR +REPLACE +INTO domains (domain, active) +VALUES (?, ?); + +-- name: DeleteDomain :exec +INSERT OR +REPLACE +INTO domains(domain, active) +VALUES (?, false); diff --git a/database/queries/favicon.sql b/database/queries/favicon.sql new file mode 100644 index 0000000..bb208e6 --- /dev/null +++ b/database/queries/favicon.sql @@ -0,0 +1,8 @@ +-- name: GetFavicons :many +SELECT host, svg, png, ico +FROM favicons; + +-- name: UpdateFaviconCache :exec +INSERT OR +REPLACE INTO favicons (host, svg, png, ico) +VALUES (?, ?, ?, ?); diff --git a/database/queries/routing.sql b/database/queries/routing.sql new file mode 100644 index 0000000..0972c3b --- /dev/null +++ b/database/queries/routing.sql @@ -0,0 +1,39 @@ +-- name: GetActiveRoutes :many +SELECT source, destination, flags +FROM routes +WHERE active = 1; + +-- name: GetActiveRedirects :many +SELECT source, destination, flags, code +FROM redirects +WHERE active = 1; + +-- name: GetAllRoutes :many +SELECT source, destination, description, flags, active +FROM routes; + +-- name: GetAllRedirects :many +SELECT source, destination, description, flags, code, active +FROM redirects; + +-- name: AddRoute :exec +INSERT OR +REPLACE +INTO routes (source, destination, description, flags, active) +VALUES (?, ?, ?, ?, ?); + +-- name: AddRedirect :exec +INSERT OR +REPLACE +INTO redirects (source, destination, description, flags, code, active) +VALUES (?, ?, ?, ?, ?, ?); + +-- name: RemoveRoute :exec +DELETE +FROM routes +WHERE source = ?; + +-- name: RemoveRedirect :exec +DELETE +FROM redirects +WHERE source = ?; diff --git a/database/routing.sql.go b/database/routing.sql.go new file mode 100644 index 0000000..e747cb5 --- /dev/null +++ b/database/routing.sql.go @@ -0,0 +1,250 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: routing.sql + +package database + +import ( + "context" + + "github.com/1f349/violet/target" +) + +const addRedirect = `-- name: AddRedirect :exec +INSERT OR +REPLACE +INTO redirects (source, destination, description, flags, code, active) +VALUES (?, ?, ?, ?, ?, ?) +` + +type AddRedirectParams struct { + Source string `json:"source"` + Destination string `json:"destination"` + Description string `json:"description"` + Flags target.Flags `json:"flags"` + Code int64 `json:"code"` + Active bool `json:"active"` +} + +func (q *Queries) AddRedirect(ctx context.Context, arg AddRedirectParams) error { + _, err := q.db.ExecContext(ctx, addRedirect, + arg.Source, + arg.Destination, + arg.Description, + arg.Flags, + arg.Code, + arg.Active, + ) + return err +} + +const addRoute = `-- name: AddRoute :exec +INSERT OR +REPLACE +INTO routes (source, destination, description, flags, active) +VALUES (?, ?, ?, ?, ?) +` + +type AddRouteParams struct { + Source string `json:"source"` + Destination string `json:"destination"` + Description string `json:"description"` + Flags target.Flags `json:"flags"` + Active bool `json:"active"` +} + +func (q *Queries) AddRoute(ctx context.Context, arg AddRouteParams) error { + _, err := q.db.ExecContext(ctx, addRoute, + arg.Source, + arg.Destination, + arg.Description, + arg.Flags, + arg.Active, + ) + return err +} + +const getActiveRedirects = `-- name: GetActiveRedirects :many +SELECT source, destination, flags, code +FROM redirects +WHERE active = 1 +` + +type GetActiveRedirectsRow struct { + Source string `json:"source"` + Destination string `json:"destination"` + Flags target.Flags `json:"flags"` + Code int64 `json:"code"` +} + +func (q *Queries) GetActiveRedirects(ctx context.Context) ([]GetActiveRedirectsRow, error) { + rows, err := q.db.QueryContext(ctx, getActiveRedirects) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetActiveRedirectsRow + for rows.Next() { + var i GetActiveRedirectsRow + if err := rows.Scan( + &i.Source, + &i.Destination, + &i.Flags, + &i.Code, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getActiveRoutes = `-- name: GetActiveRoutes :many +SELECT source, destination, flags +FROM routes +WHERE active = 1 +` + +type GetActiveRoutesRow struct { + Source string `json:"source"` + Destination string `json:"destination"` + Flags target.Flags `json:"flags"` +} + +func (q *Queries) GetActiveRoutes(ctx context.Context) ([]GetActiveRoutesRow, error) { + rows, err := q.db.QueryContext(ctx, getActiveRoutes) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetActiveRoutesRow + for rows.Next() { + var i GetActiveRoutesRow + if err := rows.Scan(&i.Source, &i.Destination, &i.Flags); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getAllRedirects = `-- name: GetAllRedirects :many +SELECT source, destination, description, flags, code, active +FROM redirects +` + +type GetAllRedirectsRow struct { + Source string `json:"source"` + Destination string `json:"destination"` + Description string `json:"description"` + Flags target.Flags `json:"flags"` + Code int64 `json:"code"` + Active bool `json:"active"` +} + +func (q *Queries) GetAllRedirects(ctx context.Context) ([]GetAllRedirectsRow, error) { + rows, err := q.db.QueryContext(ctx, getAllRedirects) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllRedirectsRow + for rows.Next() { + var i GetAllRedirectsRow + if err := rows.Scan( + &i.Source, + &i.Destination, + &i.Description, + &i.Flags, + &i.Code, + &i.Active, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getAllRoutes = `-- name: GetAllRoutes :many +SELECT source, destination, description, flags, active +FROM routes +` + +type GetAllRoutesRow struct { + Source string `json:"source"` + Destination string `json:"destination"` + Description string `json:"description"` + Flags target.Flags `json:"flags"` + Active bool `json:"active"` +} + +func (q *Queries) GetAllRoutes(ctx context.Context) ([]GetAllRoutesRow, error) { + rows, err := q.db.QueryContext(ctx, getAllRoutes) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllRoutesRow + for rows.Next() { + var i GetAllRoutesRow + if err := rows.Scan( + &i.Source, + &i.Destination, + &i.Description, + &i.Flags, + &i.Active, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const removeRedirect = `-- name: RemoveRedirect :exec +DELETE +FROM redirects +WHERE source = ? +` + +func (q *Queries) RemoveRedirect(ctx context.Context, source string) error { + _, err := q.db.ExecContext(ctx, removeRedirect, source) + return err +} + +const removeRoute = `-- name: RemoveRoute :exec +DELETE +FROM routes +WHERE source = ? +` + +func (q *Queries) RemoveRoute(ctx context.Context, source string) error { + _, err := q.db.ExecContext(ctx, removeRoute, source) + return err +} diff --git a/domains/create-table-domains.sql b/domains/create-table-domains.sql deleted file mode 100644 index 279adcf..0000000 --- a/domains/create-table-domains.sql +++ /dev/null @@ -1,6 +0,0 @@ -CREATE TABLE IF NOT EXISTS domains -( - id INTEGER PRIMARY KEY AUTOINCREMENT, - domain TEXT UNIQUE, - active INTEGER DEFAULT 1 -); diff --git a/domains/domains.go b/domains/domains.go index e7283b9..868f358 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -1,8 +1,9 @@ package domains import ( - "database/sql" + "context" _ "embed" + "github.com/1f349/violet/database" "github.com/1f349/violet/utils" "github.com/MrMelon54/rescheduler" "log" @@ -10,32 +11,22 @@ import ( "sync" ) -//go:embed create-table-domains.sql -var createTableDomains string - // Domains is the domain list and management system. type Domains struct { - db *sql.DB + db *database.Queries s *sync.RWMutex m map[string]struct{} r *rescheduler.Rescheduler } // New creates a new domain list -func New(db *sql.DB) *Domains { +func New(db *database.Queries) *Domains { a := &Domains{ db: db, s: &sync.RWMutex{}, m: make(map[string]struct{}), } a.r = rescheduler.NewRescheduler(a.threadCompile) - - // init domains table - _, err := a.db.Exec(createTableDomains) - if err != nil { - log.Printf("[WARN] Failed to generate 'domains' table\n") - return nil - } return a } @@ -93,30 +84,26 @@ func (d *Domains) internalCompile(m map[string]struct{}) error { log.Println("[Domains] Updating domains from database") // sql or something? - rows, err := d.db.Query(`select domain from domains where active = 1`) + rows, err := d.db.GetActiveDomains(context.Background()) if err != nil { return err } - defer rows.Close() - // loop through rows and scan the allowed domain names - for rows.Next() { - var name string - err = rows.Scan(&name) - if err != nil { - return err - } - m[name] = struct{}{} + for _, i := range rows { + m[i] = struct{}{} } // check for errors - return rows.Err() + return nil } func (d *Domains) Put(domain string, active bool) { d.s.Lock() defer d.s.Unlock() - _, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, active) + err := d.db.AddDomain(context.Background(), database.AddDomainParams{ + Domain: domain, + Active: active, + }) if err != nil { log.Printf("[Violet] Database error: %s\n", err) } @@ -125,7 +112,7 @@ func (d *Domains) Put(domain string, active bool) { func (d *Domains) Delete(domain string) { d.s.Lock() defer d.s.Unlock() - _, err := d.db.Exec("INSERT OR REPLACE INTO domains (domain, active) VALUES (?, ?)", domain, false) + err := d.db.DeleteDomain(context.Background(), domain) if err != nil { log.Printf("[Violet] Database error: %s\n", err) } diff --git a/domains/domains_test.go b/domains/domains_test.go index d2a0e3f..86484e7 100644 --- a/domains/domains_test.go +++ b/domains/domains_test.go @@ -1,18 +1,20 @@ package domains import ( - "database/sql" + "context" + "github.com/1f349/violet" + "github.com/1f349/violet/database" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "testing" ) func TestDomainsNew(t *testing.T) { - db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + db, err := violet.InitDB("file:TestDomainsNew?mode=memory&cache=shared") assert.NoError(t, err) domains := New(db) - _, err = db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1) + err = db.AddDomain(context.Background(), database.AddDomainParams{Domain: "example.com", Active: true}) assert.NoError(t, err) domains.Compile() @@ -27,11 +29,11 @@ func TestDomainsNew(t *testing.T) { func TestDomains_IsValid(t *testing.T) { // open sqlite database - db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + db, err := violet.InitDB("file:TestDomains_IsValid?mode=memory&cache=shared") assert.NoError(t, err) domains := New(db) - _, err = domains.db.Exec("INSERT OR IGNORE INTO domains (domain, active) VALUES (?, ?)", "example.com", 1) + err = db.AddDomain(context.Background(), database.AddDomainParams{Domain: "example.com", Active: true}) assert.NoError(t, err) domains.s.Lock() diff --git a/favicons/create-table-favicons.sql b/favicons/create-table-favicons.sql deleted file mode 100644 index edc48ec..0000000 --- a/favicons/create-table-favicons.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE TABLE IF NOT EXISTS favicons ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - host VARCHAR, - svg VARCHAR, - png VARCHAR, - ico VARCHAR -); diff --git a/favicons/favicon-image.go b/favicons/favicon-image.go index 91a49bc..bf1d42c 100644 --- a/favicons/favicon-image.go +++ b/favicons/favicon-image.go @@ -1,5 +1,7 @@ package favicons +import "database/sql" + // FaviconImage stores the url, hash and raw bytes of an image type FaviconImage struct { Url string @@ -9,9 +11,9 @@ type FaviconImage struct { // CreateFaviconImage outputs a FaviconImage with the specified URL or nil if // the URL is an empty string. -func CreateFaviconImage(url string) *FaviconImage { - if url == "" { +func CreateFaviconImage(url sql.NullString) *FaviconImage { + if !url.Valid { return nil } - return &FaviconImage{Url: url} + return &FaviconImage{Url: url.String} } diff --git a/favicons/favicons.go b/favicons/favicons.go index 0864859..7fb3d1e 100644 --- a/favicons/favicons.go +++ b/favicons/favicons.go @@ -1,10 +1,11 @@ package favicons import ( - "database/sql" + "context" _ "embed" "errors" "fmt" + "github.com/1f349/violet/database" "github.com/MrMelon54/rescheduler" "golang.org/x/sync/errgroup" "log" @@ -13,12 +14,9 @@ import ( var ErrFaviconNotFound = errors.New("favicon not found") -//go:embed create-table-favicons.sql -var createTableFavicons string - // Favicons is a dynamic favicon generator which supports overwriting favicons type Favicons struct { - db *sql.DB + db *database.Queries cmd string cLock *sync.RWMutex faviconMap map[string]*FaviconList @@ -26,7 +24,7 @@ type Favicons struct { } // New creates a new dynamic favicon generator -func New(db *sql.DB, inkscapeCmd string) *Favicons { +func New(db *database.Queries, inkscapeCmd string) *Favicons { f := &Favicons{ db: db, cmd: inkscapeCmd, @@ -35,13 +33,6 @@ func New(db *sql.DB, inkscapeCmd string) *Favicons { } f.r = rescheduler.NewRescheduler(f.threadCompile) - // init favicons table - _, err := f.db.Exec(createTableFavicons) - if err != nil { - log.Printf("[WARN] Failed to generate 'favicons' table\n") - return nil - } - // run compile to get the initial data f.Compile() return f @@ -89,29 +80,23 @@ func (f *Favicons) threadCompile() { // favicons. func (f *Favicons) internalCompile(m map[string]*FaviconList) error { // query all rows in database - query, err := f.db.Query(`select host, svg, png, ico from favicons`) + rows, err := f.db.GetFavicons(context.Background()) if err != nil { - return fmt.Errorf("failed to prepare query: %w", err) + return fmt.Errorf("failed to prepare rows: %w", err) } // loop over rows and scan in data using error group to catch errors var g errgroup.Group - for query.Next() { - var host, rawSvg, rawPng, rawIco string - err := query.Scan(&host, &rawSvg, &rawPng, &rawIco) - if err != nil { - return fmt.Errorf("failed to scan row: %w", err) - } - + for _, row := range rows { // create favicon list for this row l := &FaviconList{ - Ico: CreateFaviconImage(rawIco), - Png: CreateFaviconImage(rawPng), - Svg: CreateFaviconImage(rawSvg), + Ico: CreateFaviconImage(row.Ico), + Png: CreateFaviconImage(row.Png), + Svg: CreateFaviconImage(row.Svg), } // save the favicon list to the map - m[host] = l + m[row.Host] = l // run the pre-process in a separate goroutine g.Go(func() error { diff --git a/favicons/favicons_test.go b/favicons/favicons_test.go index 92bfd08..6faceea 100644 --- a/favicons/favicons_test.go +++ b/favicons/favicons_test.go @@ -2,8 +2,11 @@ package favicons import ( "bytes" + "context" "database/sql" _ "embed" + "github.com/1f349/violet" + "github.com/1f349/violet/database" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "image/png" @@ -22,11 +25,17 @@ var ( func TestFaviconsNew(t *testing.T) { getFaviconViaRequest = func(_ string) ([]byte, error) { return exampleSvg, nil } - db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + db, err := violet.InitDB("file:TestFaviconsNew?mode=memory&cache=shared") assert.NoError(t, err) favicons := New(db, "inkscape") - _, err = db.Exec("insert into favicons (host, svg, png, ico) values (?, ?, ?, ?)", "example.com", "https://example.com/assets/logo.svg", "", "") + err = db.UpdateFaviconCache(context.Background(), database.UpdateFaviconCacheParams{ + Host: "example.com", + Svg: sql.NullString{ + String: "https://example.com/assets/logo.svg", + Valid: true, + }, + }) assert.NoError(t, err) favicons.cLock.Lock() assert.NoError(t, favicons.internalCompile(favicons.faviconMap)) diff --git a/go.mod b/go.mod index c12f711..3492a25 100644 --- a/go.mod +++ b/go.mod @@ -1,26 +1,27 @@ module github.com/1f349/violet -go 1.21.4 +go 1.22 require ( - github.com/1f349/mjwt v0.2.1 + github.com/1f349/mjwt v0.2.5 github.com/AlecAivazis/survey/v2 v2.3.7 github.com/MrMelon54/certgen v0.0.1 github.com/MrMelon54/exit-reload v0.0.1 github.com/MrMelon54/png2ico v1.0.1 github.com/MrMelon54/rescheduler v0.0.2 github.com/MrMelon54/trie v0.0.2 + github.com/golang-migrate/migrate/v4 v4.17.0 github.com/google/subcommands v1.2.0 - github.com/google/uuid v1.4.0 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/julienschmidt/httprouter v1.3.0 - github.com/mattn/go-sqlite3 v1.14.18 - github.com/prometheus/client_golang v1.18.0 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/prometheus/client_golang v1.19.0 github.com/rs/cors v1.10.1 github.com/sethvargo/go-limiter v0.7.2 - github.com/stretchr/testify v1.8.4 - golang.org/x/net v0.19.0 - golang.org/x/sync v0.5.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/net v0.22.0 + golang.org/x/sync v0.6.0 ) require ( @@ -29,21 +30,23 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.5.0 // indirect - github.com/prometheus/common v0.45.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect - github.com/rogpeppe/go-internal v1.11.0 // indirect - golang.org/x/sys v0.15.0 // indirect - golang.org/x/term v0.15.0 // indirect + github.com/prometheus/client_model v0.6.0 // indirect + github.com/prometheus/common v0.50.0 // indirect + github.com/prometheus/procfs v0.13.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.31.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f46489e..205253d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/1f349/mjwt v0.2.1 h1:REdiM/MaNjYQwHvI39LaMPhlvMg4Vy9SgomWMsKTNz8= -github.com/1f349/mjwt v0.2.1/go.mod h1:KEs6jd9JjWrQW+8feP2pGAU7pdA3aYTqjkT/YQr73PU= +github.com/1f349/mjwt v0.2.5 h1:IxjLaali22ayTzZ628lH7j0JDdYJoj6+CJ/VktCqtXQ= +github.com/1f349/mjwt v0.2.5/go.mod h1:KEs6jd9JjWrQW+8feP2pGAU7pdA3aYTqjkT/YQr73PU= github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= github.com/MrMelon54/certgen v0.0.1 h1:ycWdZ2RlxQ5qSuejeBVv4aXjGo5hdqqL4j4EjrXnFMk= @@ -28,16 +28,21 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/golang-migrate/migrate/v4 v4.17.0 h1:rd40H3QXU0AA4IoLllFcEAEo9dYKRHYND2gB4p7xcaU= +github.com/golang-migrate/migrate/v4 v4.17.0/go.mod h1:+Cp2mtLP4/aXDTKb9wmXYitdrNx2HGs45rbWAo6OsKM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= @@ -48,6 +53,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -55,10 +62,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= -github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= @@ -66,37 +71,44 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= -github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= -github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= -github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= +github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= +github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= +github.com/prometheus/common v0.50.0 h1:YSZE6aa9+luNa2da6/Tik0q0A5AbR+U003TItK57CPQ= +github.com/prometheus/common v0.50.0/go.mod h1:wHFBCEVWVmHMUpg7pYcOm2QUR/ocQdYSJVQJKnHc3xQ= +github.com/prometheus/procfs v0.13.0 h1:GqzLlQyfsPbaEHaQkO7tbDlriv/4o5Hudv6OXHGKX7o= +github.com/prometheus/procfs v0.13.0/go.mod h1:cd4PFCR54QLnGKPaKGA6l+cfuNXtht43ZKY6tow0Y1g= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8= github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -105,12 +117,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -121,10 +133,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/initdb.go b/initdb.go new file mode 100644 index 0000000..62f8793 --- /dev/null +++ b/initdb.go @@ -0,0 +1,38 @@ +package violet + +import ( + "database/sql" + "embed" + "errors" + "github.com/1f349/violet/database" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" +) + +//go:embed database/migrations/*.sql +var migrations embed.FS + +func InitDB(p string) (*database.Queries, error) { + migDrv, err := iofs.New(migrations, "database/migrations") + if err != nil { + return nil, err + } + dbOpen, err := sql.Open("sqlite3", p) + if err != nil { + return nil, err + } + dbDrv, err := sqlite3.WithInstance(dbOpen, &sqlite3.Config{}) + if err != nil { + return nil, err + } + mig, err := migrate.NewWithInstance("iofs", migDrv, "sqlite3", dbDrv) + if err != nil { + return nil, err + } + err = mig.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + return nil, err + } + return database.New(dbOpen), nil +} diff --git a/router/create-tables.sql b/router/create-tables.sql deleted file mode 100644 index 6f03820..0000000 --- a/router/create-tables.sql +++ /dev/null @@ -1,20 +0,0 @@ -CREATE TABLE IF NOT EXISTS routes -( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source TEXT UNIQUE, - destination TEXT, - description TEXT, - flags INTEGER DEFAULT 0, - active INTEGER DEFAULT 1 -); - -CREATE TABLE IF NOT EXISTS redirects -( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source TEXT UNIQUE, - destination TEXT, - description TEXT, - flags INTEGER DEFAULT 0, - code INTEGER DEFAULT 0, - active INTEGER DEFAULT 1 -); diff --git a/router/manager.go b/router/manager.go index 24bd4e8..95d2926 100644 --- a/router/manager.go +++ b/router/manager.go @@ -1,8 +1,9 @@ package router import ( - "database/sql" + "context" _ "embed" + "github.com/1f349/violet/database" "github.com/1f349/violet/proxy" "github.com/1f349/violet/target" "github.com/MrMelon54/rescheduler" @@ -15,21 +16,16 @@ import ( // Manager is a database and mutex wrap around router allowing it to be // dynamically regenerated after updating the database of routes. type Manager struct { - db *sql.DB + db *database.Queries s *sync.RWMutex r *Router p *proxy.HybridTransport z *rescheduler.Rescheduler } -var ( - //go:embed create-tables.sql - createTables string -) - // NewManager create a new manager, initialises the routes and redirects tables // in the database and runs a first time compile. -func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager { +func NewManager(db *database.Queries, proxy *proxy.HybridTransport) *Manager { m := &Manager{ db: db, s: &sync.RWMutex{}, @@ -37,13 +33,6 @@ func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager { p: proxy, } m.z = rescheduler.NewRescheduler(m.threadCompile) - - // init routes table - _, err := m.db.Exec(createTables) - if err != nil { - log.Printf("[WARN] Failed to generate tables\n") - return nil - } return m } @@ -81,64 +70,36 @@ func (m *Manager) internalCompile(router *Router) error { log.Println("[Manager] Updating routes from database") // sql or something? - rows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`) + routeRows, err := m.db.GetActiveRoutes(context.Background()) if err != nil { return err } - defer rows.Close() - - // loop through rows and scan the options - for rows.Next() { - var ( - src, dst string - flags target.Flags - ) - err := rows.Scan(&src, &dst, &flags) - if err != nil { - return err - } + for _, row := range routeRows { router.AddRoute(target.Route{ - Src: src, - Dst: dst, - Flags: flags.NormaliseRouteFlags(), + Src: row.Source, + Dst: row.Destination, + Flags: row.Flags.NormaliseRouteFlags(), }) } - // check for errors - if err := rows.Err(); err != nil { - return err - } - // sql or something? - rows, err = m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`) + redirectsRows, err := m.db.GetActiveRedirects(context.Background()) if err != nil { return err } - defer rows.Close() - - // loop through rows and scan the options - for rows.Next() { - var ( - src, dst string - flags target.Flags - code int - ) - err := rows.Scan(&src, &dst, &flags, &code) - if err != nil { - return err - } + for _, row := range redirectsRows { router.AddRedirect(target.Redirect{ - Src: src, - Dst: dst, - Flags: flags.NormaliseRedirectFlags(), - Code: code, + Src: row.Source, + Dst: row.Destination, + Flags: row.Flags.NormaliseRedirectFlags(), + Code: row.Code, }) } // check for errors - return rows.Err() + return nil } func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) { @@ -148,15 +109,20 @@ func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) s := make([]target.RouteWithActive, 0) - query, err := m.db.Query(`SELECT source, destination, description, flags, active FROM routes`) + rows, err := m.db.GetAllRoutes(context.Background()) if err != nil { return nil, err } - for query.Next() { - var a target.RouteWithActive - if err := query.Scan(&a.Src, &a.Dst, &a.Desc, &a.Flags, &a.Active); err != nil { - return nil, err + for _, row := range rows { + a := target.RouteWithActive{ + Route: target.Route{ + Src: row.Source, + Dst: row.Destination, + Desc: row.Description, + Flags: row.Flags, + }, + Active: row.Active, } for _, i := range hosts { @@ -172,13 +138,17 @@ func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) } func (m *Manager) InsertRoute(route target.RouteWithActive) error { - _, err := m.db.Exec(`INSERT INTO routes (source, destination, description, flags, active) VALUES (?, ?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, description = excluded.description, flags = excluded.flags, active = excluded.active`, route.Src, route.Dst, route.Desc, route.Flags, route.Active) - return err + return m.db.AddRoute(context.Background(), database.AddRouteParams{ + Source: route.Src, + Destination: route.Dst, + Description: route.Desc, + Flags: route.Flags, + Active: route.Active, + }) } func (m *Manager) DeleteRoute(source string) error { - _, err := m.db.Exec(`DELETE FROM routes WHERE source = ?`, source) - return err + return m.db.RemoveRoute(context.Background(), source) } func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, error) { @@ -188,15 +158,21 @@ func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, s := make([]target.RedirectWithActive, 0) - query, err := m.db.Query(`SELECT source, destination, description, flags, code, active FROM redirects`) + rows, err := m.db.GetAllRedirects(context.Background()) if err != nil { return nil, err } - for query.Next() { - var a target.RedirectWithActive - if err := query.Scan(&a.Src, &a.Dst, &a.Desc, &a.Flags, &a.Code, &a.Active); err != nil { - return nil, err + for _, row := range rows { + a := target.RedirectWithActive{ + Redirect: target.Redirect{ + Src: row.Source, + Dst: row.Destination, + Desc: row.Description, + Flags: row.Flags, + Code: row.Code, + }, + Active: row.Active, } for _, i := range hosts { @@ -212,13 +188,18 @@ func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, } func (m *Manager) InsertRedirect(redirect target.RedirectWithActive) error { - _, err := m.db.Exec(`INSERT INTO redirects (source, destination, description, flags, code, active) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, description = excluded.description, flags = excluded.flags, code = excluded.code, active = excluded.active`, redirect.Src, redirect.Dst, redirect.Desc, redirect.Flags, redirect.Code, redirect.Active) - return err + return m.db.AddRedirect(context.Background(), database.AddRedirectParams{ + Source: redirect.Src, + Destination: redirect.Dst, + Description: redirect.Desc, + Flags: redirect.Flags, + Code: redirect.Code, + Active: redirect.Active, + }) } func (m *Manager) DeleteRedirect(source string) error { - _, err := m.db.Exec(`DELETE FROM redirects WHERE source = ?`, source) - return err + return m.db.RemoveRedirect(context.Background(), source) } // GenerateHostSearch this should help improve performance diff --git a/router/manager_test.go b/router/manager_test.go index efab3b0..50b0a9d 100644 --- a/router/manager_test.go +++ b/router/manager_test.go @@ -1,7 +1,9 @@ package router import ( - "database/sql" + "context" + "github.com/1f349/violet" + "github.com/1f349/violet/database" "github.com/1f349/violet/proxy" "github.com/1f349/violet/proxy/websocket" "github.com/1f349/violet/target" @@ -22,7 +24,7 @@ func (f *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { } func TestNewManager(t *testing.T) { - db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + db, err := violet.InitDB("file:TestNewManager?mode=memory&cache=shared") assert.NoError(t, err) ft := &fakeTransport{} @@ -39,7 +41,13 @@ func TestNewManager(t *testing.T) { assert.Equal(t, http.StatusTeapot, res.StatusCode) assert.Nil(t, ft.req) - _, err = db.Exec(`INSERT INTO routes (source, destination, flags, active) VALUES (?,?,?,1)`, "*.example.com", "127.0.0.1:8080", target.FlagAbs|target.FlagForwardHost|target.FlagForwardAddr) + err = db.AddRoute(context.Background(), database.AddRouteParams{ + Source: "*.example.com", + Destination: "127.0.0.1:8080", + Description: "", + Flags: target.FlagAbs | target.FlagForwardHost | target.FlagForwardAddr, + Active: true, + }) assert.NoError(t, err) assert.NoError(t, m.internalCompile(m.r)) @@ -52,10 +60,8 @@ func TestNewManager(t *testing.T) { } func TestManager_GetAllRoutes(t *testing.T) { - db, err := sql.Open("sqlite3", "file:GetAllRoutes?mode=memory&cache=shared") - if err != nil { - t.Fatal(err) - } + db, err := violet.InitDB("file:TestManager_GetAllRoutes?mode=memory&cache=shared") + assert.NoError(t, err) m := NewManager(db, nil) a := []error{ m.InsertRoute(target.RouteWithActive{Route: target.Route{Src: "example.com"}, Active: true}), @@ -85,10 +91,8 @@ func TestManager_GetAllRoutes(t *testing.T) { } func TestManager_GetAllRedirects(t *testing.T) { - db, err := sql.Open("sqlite3", "file:GetAllRedirects?mode=memory&cache=shared") - if err != nil { - t.Fatal(err) - } + db, err := violet.InitDB("file:TestManager_GetAllRedirects?mode=memory&cache=shared") + assert.NoError(t, err) m := NewManager(db, nil) a := []error{ m.InsertRedirect(target.RedirectWithActive{Redirect: target.Redirect{Src: "example.com"}, Active: true}), diff --git a/servers/conf/conf.go b/servers/conf/conf.go index f4541ab..cdbb428 100644 --- a/servers/conf/conf.go +++ b/servers/conf/conf.go @@ -1,8 +1,8 @@ package conf import ( - "database/sql" "github.com/1f349/mjwt" + "github.com/1f349/violet/database" errorPages "github.com/1f349/violet/error-pages" "github.com/1f349/violet/favicons" "github.com/1f349/violet/router" @@ -15,7 +15,7 @@ type Conf struct { HttpListen string // http server listen address HttpsListen string // https server listen address RateLimit uint64 // rate limit per minute - DB *sql.DB + DB *database.Queries Domains utils.DomainProvider Acme utils.AcmeChallengeProvider Certs utils.CertProvider diff --git a/servers/http.go b/servers/http.go index 69edc6c..886b750 100644 --- a/servers/http.go +++ b/servers/http.go @@ -63,7 +63,12 @@ func NewHttpServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect) }) - metricsMiddleware := metrics.New(registry, nil).WrapHandler("violet-http-insecure", r) + metricsMiddleware := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + r.ServeHTTP(rw, req) + }) + if registry != nil { + metricsMiddleware = metrics.New(registry, nil).WrapHandler("violet-http-insecure", r) + } // Create and run http server return &http.Server{ diff --git a/servers/https.go b/servers/https.go index 4d4a0c5..5a05ee5 100644 --- a/servers/https.go +++ b/servers/https.go @@ -25,14 +25,20 @@ func NewHttpsServer(conf *conf.Conf, registry *prometheus.Registry) *http.Server conf.Router.ServeHTTP(rw, req) }) favMiddleware := setupFaviconMiddleware(conf.Favicons, r) - rateLimiter := setupRateLimiter(conf.RateLimit, favMiddleware) - metricsMiddleware := metrics.New(registry, nil).WrapHandler("violet-https", rateLimiter) + metricsMeta := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - metricsMiddleware.ServeHTTP(rw, metrics.AddHostCtx(req)) + r.ServeHTTP(rw, req) }) + if registry != nil { + metricsMiddleware := metrics.New(registry, nil).WrapHandler("violet-https", favMiddleware) + metricsMeta = func(rw http.ResponseWriter, req *http.Request) { + metricsMiddleware.ServeHTTP(rw, metrics.AddHostCtx(req)) + } + } + rateLimiter := setupRateLimiter(conf.RateLimit, metricsMeta) hsts := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains") - metricsMeta.ServeHTTP(rw, req) + rateLimiter.ServeHTTP(rw, req) }) return &http.Server{ diff --git a/servers/https_test.go b/servers/https_test.go index ef86813..26bf7b7 100644 --- a/servers/https_test.go +++ b/servers/https_test.go @@ -1,7 +1,7 @@ package servers import ( - "database/sql" + "github.com/1f349/violet" "github.com/1f349/violet/certs" "github.com/1f349/violet/proxy" "github.com/1f349/violet/proxy/websocket" @@ -25,7 +25,7 @@ func (f *fakeTransport) RoundTrip(_ *http.Request) (*http.Response, error) { } func TestNewHttpsServer_RateLimit(t *testing.T) { - db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + db, err := violet.InitDB("file:TestNewHttpsServer_RateLimit?mode=memory&cache=shared") assert.NoError(t, err) ft := &fakeTransport{} diff --git a/sqlc.yaml b/sqlc.yaml new file mode 100644 index 0000000..953e616 --- /dev/null +++ b/sqlc.yaml @@ -0,0 +1,15 @@ +version: "2" +sql: + - engine: sqlite + queries: database/queries + schema: database/migrations + gen: + go: + package: "database" + out: "database" + emit_json_tags: true + overrides: + - column: "routes.flags" + go_type: "github.com/1f349/violet/target.Flags" + - column: "redirects.flags" + go_type: "github.com/1f349/violet/target.Flags" diff --git a/target/redirect.go b/target/redirect.go index a80f438..752d376 100644 --- a/target/redirect.go +++ b/target/redirect.go @@ -16,7 +16,7 @@ type Redirect struct { Dst string `json:"dst"` // redirect destination Desc string `json:"desc"` // description for admin panel use Flags Flags `json:"flags"` // extra flags - Code int `json:"code"` // status code used to redirect + Code int64 `json:"code"` // status code used to redirect } type RedirectWithActive struct { @@ -78,7 +78,7 @@ func (r Redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } // use fast redirect for speed - utils.FastRedirect(rw, req, u.String(), code) + utils.FastRedirect(rw, req, u.String(), int(code)) } // String outputs a debug string for the redirect. diff --git a/target/redirect_test.go b/target/redirect_test.go index c027432..51bd280 100644 --- a/target/redirect_test.go +++ b/target/redirect_test.go @@ -35,7 +35,7 @@ func TestRedirect_ServeHTTP(t *testing.T) { res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "https://www.example.com/hello/world", nil) i.ServeHTTP(res, req) - assert.Equal(t, i.Code, res.Code) + assert.Equal(t, i.Code, int64(res.Code)) assert.Equal(t, i.target, res.Header().Get("Location")) } } diff --git a/target/route_test.go b/target/route_test.go index 30da39e..625e18f 100644 --- a/target/route_test.go +++ b/target/route_test.go @@ -90,7 +90,7 @@ func TestRoute_ServeHTTP_Cors(t *testing.T) { assert.Equal(t, http.MethodOptions, pt.req.Method) assert.Equal(t, "http://1.1.1.1:8080/hello/test", pt.req.URL.String()) assert.Equal(t, "Origin", res.Header().Get("Vary")) - assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "https://test.example.com", res.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials")) assert.Equal(t, "Origin", res.Header().Get("Vary")) }