Server config and comments

This commit is contained in:
Melon 2023-04-22 18:11:21 +01:00
parent a4eab71e33
commit 371a150c2a
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
14 changed files with 351 additions and 39 deletions

View File

@ -1,7 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="SqlDialectMappings"> <component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/cmd/violet/init.sql" dialect="GenericSQL" />
<file url="PROJECT" dialect="SQLite" /> <file url="PROJECT" dialect="SQLite" />
</component> </component>
</project> </project>

120
certs/certs.go Normal file
View File

@ -0,0 +1,120 @@
package certs
import (
"code.mrmelon54.com/melon/certgen"
"crypto/tls"
"fmt"
"github.com/MrMelon54/violet/utils"
"io/fs"
"log"
"path/filepath"
"sync"
)
type Certs struct {
cDir fs.FS
kDir fs.FS
s *sync.RWMutex
m map[string]*tls.Certificate
}
func New(certDir fs.FS, keyDir fs.FS) *Certs {
a := &Certs{
cDir: certDir,
kDir: keyDir,
s: &sync.RWMutex{},
m: make(map[string]*tls.Certificate),
}
a.Compile()
return a
}
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
// safety read lock
c.s.RLock()
defer c.s.RUnlock()
// lookup and return cert
if cert, ok := c.m[domain]; ok {
return cert
}
// lookup and return wildcard cert
if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok {
if cert, ok := c.m[wildcardDomain]; ok {
return cert
}
}
// no cert found
return nil
}
func (c *Certs) Compile() {
// async compile magic
go func() {
certMap := make(map[string]*tls.Certificate)
err := c.internalCompile(certMap)
if err != nil {
log.Printf("[Certs] Compile failed: %s\n", err)
return
}
// lock while replacing the map
c.s.Lock()
c.m = certMap
c.s.Unlock()
}()
}
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
// try to read dir
files, err := fs.ReadDir(c.cDir, "")
if err != nil {
return fmt.Errorf("failed to read cert dir: %w", err)
}
log.Printf("[Certs] Compiling lookup table for %d certificates\n", len(files))
// find and parse certs
for _, i := range files {
// skip dirs
if i.IsDir() {
continue
}
// get file name and extension
name := i.Name()
ext := filepath.Ext(name)
keyName := name[:len(name)-len(ext)] + "key"
// try to read cert file
certData, err := fs.ReadFile(c.cDir, name)
if err != nil {
return fmt.Errorf("failed to read cert file '%s': %w", name, err)
}
// try to read key file
keyData, err := fs.ReadFile(c.kDir, keyName)
if err != nil {
return fmt.Errorf("failed to read key file '%s': %w", keyName, err)
}
// load key pair
pair, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return fmt.Errorf("failed to load x509 key pair '%s + %s': %w", name, keyName, err)
}
// load tls leaf
cert := &pair
leaf := certgen.TlsLeaf(cert)
// save in map under each dns name
for _, j := range leaf.DNSNames {
m[j] = cert
}
}
// well no errors happened
return nil
}

View File

@ -1 +0,0 @@
create table acme_challenge (id integer not null primary key, key varchar, value varchar);

View File

@ -3,8 +3,8 @@ package main
import ( import (
"database/sql" "database/sql"
_ "embed" _ "embed"
"errors"
"flag" "flag"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/domains" "github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/proxy" "github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router" "github.com/MrMelon54/violet/router"
@ -15,45 +15,65 @@ import (
"os" "os"
) )
//go:embed init.sql
var initSql string
var ( var (
databasePath = flag.String("db", "", "/path/to/database.sqlite") databasePath = flag.String("db", "", "/path/to/database.sqlite")
certPath = flag.String("cert", "", "/path/to/certificates") keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening") certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
) )
func main() { func main() {
log.Println("[Violet] Starting...") log.Println("[Violet] Starting...")
_, err := os.Stat(*certPath) // create paths
if errors.Is(err, os.ErrNotExist) { err := os.MkdirAll(*certPath, os.ModePerm)
log.Fatalf("[Violet] Certificate path '%s' does not exists", *certPath) if err != nil {
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
}
err = os.MkdirAll(*keyPath, os.ModePerm)
if err != nil {
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
} }
_, err = os.Stat(*databasePath) // open sqlite database
dbExists := !errors.Is(err, os.ErrNotExist)
db, err := sql.Open("sqlite3", *databasePath) db, err := sql.Open("sqlite3", *databasePath)
if err != nil { if err != nil {
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath) log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
} }
if !dbExists { // load allowed domains
log.Println("[Violet] Creating new database and running init.sql")
_, err = db.Exec(initSql)
if err != nil {
log.Fatalf("[Violet] Failed to run init.sql")
}
}
allowedDomains := domains.New(db) allowedDomains := domains.New(db)
// load allowed certificates
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath))
// create reverse proxy and
reverseProxy := proxy.CreateHybridReverseProxy() reverseProxy := proxy.CreateHybridReverseProxy()
r := router.New(reverseProxy) r := router.New(reverseProxy)
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains}) srvConf := &servers.Conf{
servers.NewHttpServer(*httpListen, 0, allowedDomains, db) ApiListen: *apiListen,
HttpListen: *httpListen,
HttpsListen: *httpsListen,
DB: db,
Domains: allowedDomains,
Certs: allowedCerts,
Favicons: dynamicFavicons,
Verify: apiVerify,
ErrorPages: dynamicErrorPages,
Proxy: reverseProxy,
}
if *apiListen != "" {
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains})
}
if *httpListen != "" {
servers.NewHttpServer(srvConf)
}
if *httpsListen != "" {
servers.NewHttpsServer(srvConf)
}
} }

View File

@ -14,14 +14,18 @@ type Domains struct {
m map[string]struct{} m map[string]struct{}
} }
// New creates a new domain list
func New(db *sql.DB) *Domains { func New(db *sql.DB) *Domains {
return &Domains{ a := &Domains{
db: db, db: db,
s: &sync.RWMutex{}, s: &sync.RWMutex{},
m: make(map[string]struct{}), m: make(map[string]struct{}),
} }
a.Compile()
return a
} }
// IsValid returns true if a domain is valid.
func (d *Domains) IsValid(host string) bool { func (d *Domains) IsValid(host string) bool {
// remove the port // remove the port
domain, ok := utils.GetDomainWithoutPort(host) domain, ok := utils.GetDomainWithoutPort(host)
@ -34,6 +38,7 @@ func (d *Domains) IsValid(host string) bool {
defer d.s.RUnlock() defer d.s.RUnlock()
// check root domains `www.example.com`, `example.com`, `com` // check root domains `www.example.com`, `example.com`, `com`
// TODO: could be faster using indexes and cropping the string?
n := strings.Split(domain, ".") n := strings.Split(domain, ".")
for i := 0; i < len(n); i++ { for i := 0; i < len(n); i++ {
if _, ok := d.m[strings.Join(n[i:], ".")]; ok { if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
@ -43,6 +48,10 @@ func (d *Domains) IsValid(host string) bool {
return false return false
} }
// Compile downloads the list of domains from the database and loads them into
// memory for faster lookups.
//
// This method is asynchronous and uses locks for safety.
func (d *Domains) Compile() { func (d *Domains) Compile() {
// async compile magic // async compile magic
go func() { go func() {
@ -59,6 +68,8 @@ func (d *Domains) Compile() {
}() }()
} }
// internalCompile is a hidden internal method for querying the database during
// the Compile() method.
func (d *Domains) internalCompile(m map[string]struct{}) error { func (d *Domains) internalCompile(m map[string]struct{}) error {
log.Println("[Domains] Updating domains from database") log.Println("[Domains] Updating domains from database")

View File

@ -0,0 +1,46 @@
package error_pages
import (
"fmt"
"net/http"
"sync"
)
// ErrorPages stores the custom error pages and is called by the servers to
// output meaningful pages for HTTP error codes
type ErrorPages struct {
s *sync.RWMutex
m map[int]func(rw http.ResponseWriter)
generic func(rw http.ResponseWriter, code int)
dir string
}
func New(dir string) *ErrorPages {
return &ErrorPages{
s: &sync.RWMutex{},
m: make(map[int]func(rw http.ResponseWriter)),
generic: func(rw http.ResponseWriter, code int) {
a := http.StatusText(code)
if a != "" {
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
return
}
http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code)
},
dir: dir,
}
}
func (e *ErrorPages) Compile() {
}
func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
e.s.RLock()
defer e.s.RUnlock()
if p, ok := e.m[code]; ok {
p(rw)
return
}
e.generic(rw, code)
}

3
favicons/favicons.go Normal file
View File

@ -0,0 +1,3 @@
package favicons
type Favicons struct{}

3
go.mod
View File

@ -3,12 +3,15 @@ module github.com/MrMelon54/violet
go 1.20 go 1.20
require ( require (
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1
code.mrmelon54.com/melon/summer-utils v0.0.3 code.mrmelon54.com/melon/summer-utils v0.0.3
github.com/MrMelon54/trie v0.0.2 github.com/MrMelon54/trie v0.0.2
github.com/gorilla/mux v1.8.0
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
github.com/mrmelon54/mjwt v0.0.1 github.com/mrmelon54/mjwt v0.0.1
github.com/rs/cors v1.9.0 github.com/rs/cors v1.9.0
github.com/sethvargo/go-limiter v0.7.2
github.com/stretchr/testify v1.8.2 github.com/stretchr/testify v1.8.2
golang.org/x/net v0.9.0 golang.org/x/net v0.9.0
) )

6
go.sum
View File

@ -1,3 +1,5 @@
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1 h1:tll8DwvO1CL+xXJIMLyDmQYoYr/gA4BkcUFtNHB1BFo=
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1/go.mod h1:Liyhe1bkNyeVfw6LicCgrQ+4oUT/w/qONLjvejkUim0=
code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q= code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q=
code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw= code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw=
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY= github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
@ -7,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
@ -19,6 +23,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

24
servers/conf.go Normal file
View File

@ -0,0 +1,24 @@
package servers
import (
"database/sql"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/domains"
errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons"
"github.com/mrmelon54/mjwt"
"net/http/httputil"
)
type Conf struct {
ApiListen string
HttpListen string
HttpsListen string
DB *sql.DB
Domains *domains.Domains
Certs *certs.Certs
Favicons *favicons.Favicons
Verify mjwt.Provider
ErrorPages *errorPages.ErrorPages
Proxy *httputil.ReverseProxy
}

View File

@ -1,9 +1,7 @@
package servers package servers
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/MrMelon54/violet/domains"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"log" "log"
@ -17,9 +15,13 @@ import (
// //
// `/.well-known/acme-challenge/{token}` is used for outputting answers for // `/.well-known/acme-challenge/{token}` is used for outputting answers for
// acme challenges, this is used for Lets Encrypt HTTP verification. // acme challenges, this is used for Lets Encrypt HTTP verification.
func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, db *sql.DB) *http.Server { func NewHttpServer(conf *Conf) *http.Server {
r := httprouter.New() r := httprouter.New()
var secureExtend string var secureExtend string
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
if !ok {
httpsPort = 443
}
if httpsPort != 443 { if httpsPort != 443 {
secureExtend = fmt.Sprintf(":%d", httpsPort) secureExtend = fmt.Sprintf(":%d", httpsPort)
} }
@ -28,7 +30,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok { if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
// check if the host is valid // check if the host is valid
if !domainCheck.IsValid(req.Host) { if !conf.Domains.IsValid(req.Host) {
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
return return
} }
@ -41,7 +43,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d
} }
// prepare for executing query // prepare for executing query
prepare, err := db.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?") prepare, err := conf.DB.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?")
if err != nil { if err != nil {
utils.RespondHttpStatus(rw, http.StatusInternalServerError) utils.RespondHttpStatus(rw, http.StatusInternalServerError)
return return
@ -79,7 +81,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d
// Create and run http server // Create and run http server
s := &http.Server{ s := &http.Server{
Addr: listen, Addr: conf.HttpListen,
Handler: r, Handler: r,
ReadTimeout: time.Minute, ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute, ReadHeaderTimeout: time.Minute,

View File

@ -1 +1,79 @@
package servers package servers
import (
"crypto/tls"
"fmt"
"github.com/MrMelon54/violet/router"
"github.com/MrMelon54/violet/utils"
"github.com/gorilla/mux"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
"log"
"net"
"net/http"
"time"
)
// NewHttpsServer creates and runs a http server containing the public https
// endpoints for the reverse proxy.
func NewHttpsServer(conf *Conf) *http.Server {
r := router.New(conf.Proxy)
s := &http.Server{
Addr: conf.HttpsListen,
Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusNotImplemented)
_, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req)))
_ = r
// TODO: serve from router and proxy
// r.ServeHTTP(rw, req)
})),
DisableGeneralOptionsHandler: false,
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains
if !conf.Domains.IsValid(info.ServerName) {
return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName)
}
// find a certificate
cert := conf.Certs.GetCertForDomain(info.ServerName)
if cert == nil {
return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName)
}
// time to return
return cert, nil
}},
ReadTimeout: 150 * time.Second,
ReadHeaderTimeout: 150 * time.Second,
WriteTimeout: 150 * time.Second,
IdleTimeout: 150 * time.Second,
MaxHeaderBytes: 4096000,
ConnState: func(conn net.Conn, state http.ConnState) {
fmt.Printf("%s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String())
},
}
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttps("HTTPS", s)
return s
}
// setupRateLimiter is an internal function to create a middleware to manage
// rate limits.
func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
// create memory store
store, err := memorystore.New(&memorystore.Config{
Tokens: rateLimit,
Interval: time.Minute,
})
if err != nil {
log.Fatalln(err)
}
// create a middleware using ips as the key for rate limits
middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc())
if err != nil {
log.Fatalln(err)
}
return middleware.Handle
}

View File

@ -1,16 +1,17 @@
package utils package utils
import ( import (
"fmt" "strconv"
"strings" "strings"
) )
func SplitDomainPort(host string, defaultPort uint16) (domain string, port uint16, ok bool) { func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) {
a := strings.SplitN(host, ":", 2) a := strings.SplitN(host, ":", 2)
switch len(a) { switch len(a) {
case 2: case 2:
domain = a[0] domain = a[0]
_, err := fmt.Sscanf(a[1], "%d", &port) p, err := strconv.Atoi(a[1])
port = p
ok = err == nil ok = err == nil
case 1: case 1:
domain = a[0] domain = a[0]

View File

@ -9,12 +9,12 @@ func TestSplitDomainPort(t *testing.T) {
domain, port, ok := SplitDomainPort("www.example.com:5612", 443) domain, port, ok := SplitDomainPort("www.example.com:5612", 443)
assert.True(t, ok, "Output should be true") assert.True(t, ok, "Output should be true")
assert.Equal(t, "www.example.com", domain) assert.Equal(t, "www.example.com", domain)
assert.Equal(t, uint16(5612), port) assert.Equal(t, int(5612), port)
domain, port, ok = SplitDomainPort("example.com", 443) domain, port, ok = SplitDomainPort("example.com", 443)
assert.True(t, ok, "Output should be true") assert.True(t, ok, "Output should be true")
assert.Equal(t, "example.com", domain) assert.Equal(t, "example.com", domain)
assert.Equal(t, uint16(443), port) assert.Equal(t, int(443), port)
} }
func TestDomainWithoutPort(t *testing.T) { func TestDomainWithoutPort(t *testing.T) {