mirror of
https://github.com/1f349/violet.git
synced 2024-11-24 12:21:33 +00:00
Server config and comments
This commit is contained in:
parent
a4eab71e33
commit
371a150c2a
@ -1,7 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SqlDialectMappings">
|
||||
<file url="file://$PROJECT_DIR$/cmd/violet/init.sql" dialect="GenericSQL" />
|
||||
<file url="PROJECT" dialect="SQLite" />
|
||||
</component>
|
||||
</project>
|
120
certs/certs.go
Normal file
120
certs/certs.go
Normal file
@ -0,0 +1,120 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"code.mrmelon54.com/melon/certgen"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"io/fs"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Certs struct {
|
||||
cDir fs.FS
|
||||
kDir fs.FS
|
||||
s *sync.RWMutex
|
||||
m map[string]*tls.Certificate
|
||||
}
|
||||
|
||||
func New(certDir fs.FS, keyDir fs.FS) *Certs {
|
||||
a := &Certs{
|
||||
cDir: certDir,
|
||||
kDir: keyDir,
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[string]*tls.Certificate),
|
||||
}
|
||||
a.Compile()
|
||||
return a
|
||||
}
|
||||
|
||||
func (c *Certs) GetCertForDomain(domain string) *tls.Certificate {
|
||||
// safety read lock
|
||||
c.s.RLock()
|
||||
defer c.s.RUnlock()
|
||||
|
||||
// lookup and return cert
|
||||
if cert, ok := c.m[domain]; ok {
|
||||
return cert
|
||||
}
|
||||
|
||||
// lookup and return wildcard cert
|
||||
if wildcardDomain, ok := utils.ReplaceSubdomainWithWildcard(domain); ok {
|
||||
if cert, ok := c.m[wildcardDomain]; ok {
|
||||
return cert
|
||||
}
|
||||
}
|
||||
|
||||
// no cert found
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Certs) Compile() {
|
||||
// async compile magic
|
||||
go func() {
|
||||
certMap := make(map[string]*tls.Certificate)
|
||||
err := c.internalCompile(certMap)
|
||||
if err != nil {
|
||||
log.Printf("[Certs] Compile failed: %s\n", err)
|
||||
return
|
||||
}
|
||||
// lock while replacing the map
|
||||
c.s.Lock()
|
||||
c.m = certMap
|
||||
c.s.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
|
||||
// try to read dir
|
||||
files, err := fs.ReadDir(c.cDir, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read cert dir: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[Certs] Compiling lookup table for %d certificates\n", len(files))
|
||||
|
||||
// find and parse certs
|
||||
for _, i := range files {
|
||||
// skip dirs
|
||||
if i.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// get file name and extension
|
||||
name := i.Name()
|
||||
ext := filepath.Ext(name)
|
||||
keyName := name[:len(name)-len(ext)] + "key"
|
||||
|
||||
// try to read cert file
|
||||
certData, err := fs.ReadFile(c.cDir, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read cert file '%s': %w", name, err)
|
||||
}
|
||||
|
||||
// try to read key file
|
||||
keyData, err := fs.ReadFile(c.kDir, keyName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read key file '%s': %w", keyName, err)
|
||||
}
|
||||
|
||||
// load key pair
|
||||
pair, err := tls.X509KeyPair(certData, keyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load x509 key pair '%s + %s': %w", name, keyName, err)
|
||||
}
|
||||
|
||||
// load tls leaf
|
||||
cert := &pair
|
||||
leaf := certgen.TlsLeaf(cert)
|
||||
|
||||
// save in map under each dns name
|
||||
for _, j := range leaf.DNSNames {
|
||||
m[j] = cert
|
||||
}
|
||||
}
|
||||
|
||||
// well no errors happened
|
||||
return nil
|
||||
}
|
@ -1 +0,0 @@
|
||||
create table acme_challenge (id integer not null primary key, key varchar, value varchar);
|
@ -3,8 +3,8 @@ package main
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"github.com/MrMelon54/violet/certs"
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
@ -15,12 +15,11 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
//go:embed init.sql
|
||||
var initSql string
|
||||
|
||||
var (
|
||||
databasePath = flag.String("db", "", "/path/to/database.sqlite")
|
||||
certPath = flag.String("cert", "", "/path/to/certificates")
|
||||
keyPath = flag.String("keys", "", "/path/to/keys : path contains the keys with names matching the certificates and '.key' extensions")
|
||||
certPath = flag.String("certs", "", "/path/to/certificates : path contains the certificates to load in armoured PEM encoding")
|
||||
errorPagePath = flag.String("errors", "", "/path/to/error-pages : path contains the custom error pages")
|
||||
apiListen = flag.String("api", "127.0.0.1:8080", "address for api listening")
|
||||
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
|
||||
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
|
||||
@ -29,31 +28,52 @@ var (
|
||||
func main() {
|
||||
log.Println("[Violet] Starting...")
|
||||
|
||||
_, err := os.Stat(*certPath)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
log.Fatalf("[Violet] Certificate path '%s' does not exists", *certPath)
|
||||
// create paths
|
||||
err := os.MkdirAll(*certPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate path '%s' does not exist", *certPath)
|
||||
}
|
||||
err = os.MkdirAll(*keyPath, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to create certificate key path '%s' does not exist", *keyPath)
|
||||
}
|
||||
|
||||
_, err = os.Stat(*databasePath)
|
||||
dbExists := !errors.Is(err, os.ErrNotExist)
|
||||
|
||||
// open sqlite database
|
||||
db, err := sql.Open("sqlite3", *databasePath)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to open database '%s'...", *databasePath)
|
||||
}
|
||||
|
||||
if !dbExists {
|
||||
log.Println("[Violet] Creating new database and running init.sql")
|
||||
_, err = db.Exec(initSql)
|
||||
if err != nil {
|
||||
log.Fatalf("[Violet] Failed to run init.sql")
|
||||
}
|
||||
}
|
||||
|
||||
// load allowed domains
|
||||
allowedDomains := domains.New(db)
|
||||
|
||||
// load allowed certificates
|
||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath))
|
||||
|
||||
// create reverse proxy and
|
||||
reverseProxy := proxy.CreateHybridReverseProxy()
|
||||
r := router.New(reverseProxy)
|
||||
|
||||
srvConf := &servers.Conf{
|
||||
ApiListen: *apiListen,
|
||||
HttpListen: *httpListen,
|
||||
HttpsListen: *httpsListen,
|
||||
DB: db,
|
||||
Domains: allowedDomains,
|
||||
Certs: allowedCerts,
|
||||
Favicons: dynamicFavicons,
|
||||
Verify: apiVerify,
|
||||
ErrorPages: dynamicErrorPages,
|
||||
Proxy: reverseProxy,
|
||||
}
|
||||
|
||||
if *apiListen != "" {
|
||||
servers.NewApiServer(*apiListen, nil, utils.MultiCompilable{allowedDomains})
|
||||
servers.NewHttpServer(*httpListen, 0, allowedDomains, db)
|
||||
}
|
||||
if *httpListen != "" {
|
||||
servers.NewHttpServer(srvConf)
|
||||
}
|
||||
if *httpsListen != "" {
|
||||
servers.NewHttpsServer(srvConf)
|
||||
}
|
||||
}
|
||||
|
@ -14,14 +14,18 @@ type Domains struct {
|
||||
m map[string]struct{}
|
||||
}
|
||||
|
||||
// New creates a new domain list
|
||||
func New(db *sql.DB) *Domains {
|
||||
return &Domains{
|
||||
a := &Domains{
|
||||
db: db,
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[string]struct{}),
|
||||
}
|
||||
a.Compile()
|
||||
return a
|
||||
}
|
||||
|
||||
// IsValid returns true if a domain is valid.
|
||||
func (d *Domains) IsValid(host string) bool {
|
||||
// remove the port
|
||||
domain, ok := utils.GetDomainWithoutPort(host)
|
||||
@ -34,6 +38,7 @@ func (d *Domains) IsValid(host string) bool {
|
||||
defer d.s.RUnlock()
|
||||
|
||||
// check root domains `www.example.com`, `example.com`, `com`
|
||||
// TODO: could be faster using indexes and cropping the string?
|
||||
n := strings.Split(domain, ".")
|
||||
for i := 0; i < len(n); i++ {
|
||||
if _, ok := d.m[strings.Join(n[i:], ".")]; ok {
|
||||
@ -43,6 +48,10 @@ func (d *Domains) IsValid(host string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Compile downloads the list of domains from the database and loads them into
|
||||
// memory for faster lookups.
|
||||
//
|
||||
// This method is asynchronous and uses locks for safety.
|
||||
func (d *Domains) Compile() {
|
||||
// async compile magic
|
||||
go func() {
|
||||
@ -59,6 +68,8 @@ func (d *Domains) Compile() {
|
||||
}()
|
||||
}
|
||||
|
||||
// internalCompile is a hidden internal method for querying the database during
|
||||
// the Compile() method.
|
||||
func (d *Domains) internalCompile(m map[string]struct{}) error {
|
||||
log.Println("[Domains] Updating domains from database")
|
||||
|
||||
|
46
error-pages/error-pages.go
Normal file
46
error-pages/error-pages.go
Normal file
@ -0,0 +1,46 @@
|
||||
package error_pages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrorPages stores the custom error pages and is called by the servers to
|
||||
// output meaningful pages for HTTP error codes
|
||||
type ErrorPages struct {
|
||||
s *sync.RWMutex
|
||||
m map[int]func(rw http.ResponseWriter)
|
||||
generic func(rw http.ResponseWriter, code int)
|
||||
dir string
|
||||
}
|
||||
|
||||
func New(dir string) *ErrorPages {
|
||||
return &ErrorPages{
|
||||
s: &sync.RWMutex{},
|
||||
m: make(map[int]func(rw http.ResponseWriter)),
|
||||
generic: func(rw http.ResponseWriter, code int) {
|
||||
a := http.StatusText(code)
|
||||
if a != "" {
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", code, a), code)
|
||||
return
|
||||
}
|
||||
http.Error(rw, fmt.Sprintf("%d Unknown Error Code\n", code), code)
|
||||
},
|
||||
dir: dir,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ErrorPages) Compile() {
|
||||
|
||||
}
|
||||
|
||||
func (e *ErrorPages) ServeError(rw http.ResponseWriter, code int) {
|
||||
e.s.RLock()
|
||||
defer e.s.RUnlock()
|
||||
if p, ok := e.m[code]; ok {
|
||||
p(rw)
|
||||
return
|
||||
}
|
||||
e.generic(rw, code)
|
||||
}
|
3
favicons/favicons.go
Normal file
3
favicons/favicons.go
Normal file
@ -0,0 +1,3 @@
|
||||
package favicons
|
||||
|
||||
type Favicons struct{}
|
3
go.mod
3
go.mod
@ -3,12 +3,15 @@ module github.com/MrMelon54/violet
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1
|
||||
code.mrmelon54.com/melon/summer-utils v0.0.3
|
||||
github.com/MrMelon54/trie v0.0.2
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/julienschmidt/httprouter v1.3.0
|
||||
github.com/mattn/go-sqlite3 v1.14.16
|
||||
github.com/mrmelon54/mjwt v0.0.1
|
||||
github.com/rs/cors v1.9.0
|
||||
github.com/sethvargo/go-limiter v0.7.2
|
||||
github.com/stretchr/testify v1.8.2
|
||||
golang.org/x/net v0.9.0
|
||||
)
|
||||
|
6
go.sum
6
go.sum
@ -1,3 +1,5 @@
|
||||
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1 h1:tll8DwvO1CL+xXJIMLyDmQYoYr/gA4BkcUFtNHB1BFo=
|
||||
code.mrmelon54.com/melon/certgen v0.0.0-20220830133534-0fb4cb7e67d1/go.mod h1:Liyhe1bkNyeVfw6LicCgrQ+4oUT/w/qONLjvejkUim0=
|
||||
code.mrmelon54.com/melon/summer-utils v0.0.3 h1:Bz4o5BBOqWCNGpKkxUum4rwMn/DIdyMCKGQ/D6SXD6Q=
|
||||
code.mrmelon54.com/melon/summer-utils v0.0.3/go.mod h1:Gh/baXSzkf1ZhHonpPP8oQkyhhmFZcC2yTMlrwclDUw=
|
||||
github.com/MrMelon54/trie v0.0.2 h1:ZXWcX5ij62O9K4I/anuHmVg8L3tF0UGdlPceAASwKEY=
|
||||
@ -7,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
|
||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
|
||||
@ -19,6 +23,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
|
||||
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
|
||||
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
|
24
servers/conf.go
Normal file
24
servers/conf.go
Normal file
@ -0,0 +1,24 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/MrMelon54/violet/certs"
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/mrmelon54/mjwt"
|
||||
"net/http/httputil"
|
||||
)
|
||||
|
||||
type Conf struct {
|
||||
ApiListen string
|
||||
HttpListen string
|
||||
HttpsListen string
|
||||
DB *sql.DB
|
||||
Domains *domains.Domains
|
||||
Certs *certs.Certs
|
||||
Favicons *favicons.Favicons
|
||||
Verify mjwt.Provider
|
||||
ErrorPages *errorPages.ErrorPages
|
||||
Proxy *httputil.ReverseProxy
|
||||
}
|
@ -1,9 +1,7 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"log"
|
||||
@ -17,9 +15,13 @@ import (
|
||||
//
|
||||
// `/.well-known/acme-challenge/{token}` is used for outputting answers for
|
||||
// acme challenges, this is used for Lets Encrypt HTTP verification.
|
||||
func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, db *sql.DB) *http.Server {
|
||||
func NewHttpServer(conf *Conf) *http.Server {
|
||||
r := httprouter.New()
|
||||
var secureExtend string
|
||||
_, httpsPort, ok := utils.SplitDomainPort(conf.HttpsListen, 443)
|
||||
if !ok {
|
||||
httpsPort = 443
|
||||
}
|
||||
if httpsPort != 443 {
|
||||
secureExtend = fmt.Sprintf(":%d", httpsPort)
|
||||
}
|
||||
@ -28,7 +30,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d
|
||||
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||
// check if the host is valid
|
||||
if !domainCheck.IsValid(req.Host) {
|
||||
if !conf.Domains.IsValid(req.Host) {
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
|
||||
return
|
||||
}
|
||||
@ -41,7 +43,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d
|
||||
}
|
||||
|
||||
// prepare for executing query
|
||||
prepare, err := db.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?")
|
||||
prepare, err := conf.DB.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?")
|
||||
if err != nil {
|
||||
utils.RespondHttpStatus(rw, http.StatusInternalServerError)
|
||||
return
|
||||
@ -79,7 +81,7 @@ func NewHttpServer(listen string, httpsPort int, domainCheck *domains.Domains, d
|
||||
|
||||
// Create and run http server
|
||||
s := &http.Server{
|
||||
Addr: listen,
|
||||
Addr: conf.HttpListen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
|
@ -1 +1,79 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sethvargo/go-limiter/httplimit"
|
||||
"github.com/sethvargo/go-limiter/memorystore"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewHttpsServer creates and runs a http server containing the public https
|
||||
// endpoints for the reverse proxy.
|
||||
func NewHttpsServer(conf *Conf) *http.Server {
|
||||
r := router.New(conf.Proxy)
|
||||
|
||||
s := &http.Server{
|
||||
Addr: conf.HttpsListen,
|
||||
Handler: setupRateLimiter(300).Middleware(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusNotImplemented)
|
||||
_, _ = rw.Write([]byte(fmt.Sprintf("%#v\n", req)))
|
||||
_ = r
|
||||
// TODO: serve from router and proxy
|
||||
// r.ServeHTTP(rw, req)
|
||||
})),
|
||||
DisableGeneralOptionsHandler: false,
|
||||
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// error out on invalid domains
|
||||
if !conf.Domains.IsValid(info.ServerName) {
|
||||
return nil, fmt.Errorf("invalid hostname used: '%s'", info.ServerName)
|
||||
}
|
||||
|
||||
// find a certificate
|
||||
cert := conf.Certs.GetCertForDomain(info.ServerName)
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("failed to find certificate for: '%s'", info.ServerName)
|
||||
}
|
||||
|
||||
// time to return
|
||||
return cert, nil
|
||||
}},
|
||||
ReadTimeout: 150 * time.Second,
|
||||
ReadHeaderTimeout: 150 * time.Second,
|
||||
WriteTimeout: 150 * time.Second,
|
||||
IdleTimeout: 150 * time.Second,
|
||||
MaxHeaderBytes: 4096000,
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
fmt.Printf("%s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String())
|
||||
},
|
||||
}
|
||||
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", s.Addr)
|
||||
go utils.RunBackgroundHttps("HTTPS", s)
|
||||
return s
|
||||
}
|
||||
|
||||
// setupRateLimiter is an internal function to create a middleware to manage
|
||||
// rate limits.
|
||||
func setupRateLimiter(rateLimit uint64) mux.MiddlewareFunc {
|
||||
// create memory store
|
||||
store, err := memorystore.New(&memorystore.Config{
|
||||
Tokens: rateLimit,
|
||||
Interval: time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
// create a middleware using ips as the key for rate limits
|
||||
middleware, err := httplimit.NewMiddleware(store, httplimit.IPKeyFunc())
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
return middleware.Handle
|
||||
}
|
||||
|
@ -1,16 +1,17 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func SplitDomainPort(host string, defaultPort uint16) (domain string, port uint16, ok bool) {
|
||||
func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok bool) {
|
||||
a := strings.SplitN(host, ":", 2)
|
||||
switch len(a) {
|
||||
case 2:
|
||||
domain = a[0]
|
||||
_, err := fmt.Sscanf(a[1], "%d", &port)
|
||||
p, err := strconv.Atoi(a[1])
|
||||
port = p
|
||||
ok = err == nil
|
||||
case 1:
|
||||
domain = a[0]
|
||||
|
@ -9,12 +9,12 @@ func TestSplitDomainPort(t *testing.T) {
|
||||
domain, port, ok := SplitDomainPort("www.example.com:5612", 443)
|
||||
assert.True(t, ok, "Output should be true")
|
||||
assert.Equal(t, "www.example.com", domain)
|
||||
assert.Equal(t, uint16(5612), port)
|
||||
assert.Equal(t, int(5612), port)
|
||||
|
||||
domain, port, ok = SplitDomainPort("example.com", 443)
|
||||
assert.True(t, ok, "Output should be true")
|
||||
assert.Equal(t, "example.com", domain)
|
||||
assert.Equal(t, uint16(443), port)
|
||||
assert.Equal(t, int(443), port)
|
||||
}
|
||||
|
||||
func TestDomainWithoutPort(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user