Write lots of tests

This commit is contained in:
Melon 2023-06-04 22:28:48 +01:00
parent 1f487eb80c
commit afc661c62b
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
18 changed files with 581 additions and 90 deletions

View File

@ -121,6 +121,10 @@ func (c *Certs) Compile() {
// internalCompile is a hidden internal method for loading the certificate and // internalCompile is a hidden internal method for loading the certificate and
// key files // key files
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error { func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
if c.cDir == nil {
return nil
}
// try to read dir // try to read dir
files, err := fs.ReadDir(c.cDir, "") files, err := fs.ReadDir(c.cDir, "")
if err != nil { if err != nil {

View File

@ -33,6 +33,7 @@ var (
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
inkscapeCmd = flag.String("inkscape", "inkscape", "Path to inkscape binary") inkscapeCmd = flag.String("inkscape", "inkscape", "Path to inkscape binary")
rateLimit = flag.Uint64("ratelimit", 300, "Rate limit (max requests per minute)")
) )
func main() { func main() {
@ -61,6 +62,7 @@ func main() {
} }
allowedDomains := domains.New(db) // load allowed domains allowedDomains := domains.New(db) // load allowed domains
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
reverseProxy := proxy.NewHybridTransport() // load reverse proxy reverseProxy := proxy.NewHybridTransport() // load reverse proxy
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
@ -72,8 +74,10 @@ func main() {
ApiListen: *apiListen, ApiListen: *apiListen,
HttpListen: *httpListen, HttpListen: *httpListen,
HttpsListen: *httpsListen, HttpsListen: *httpsListen,
RateLimit: *rateLimit,
DB: db, DB: db,
Domains: allowedDomains, Domains: allowedDomains,
Acme: acmeChallenges,
Certs: allowedCerts, Certs: allowedCerts,
Favicons: dynamicFavicons, Favicons: dynamicFavicons,
Verify: nil, // TODO: add mjwt verify support Verify: nil, // TODO: add mjwt verify support
@ -84,12 +88,18 @@ func main() {
var srvApi, srvHttp, srvHttps *http.Server var srvApi, srvHttp, srvHttps *http.Server
if *apiListen != "" { if *apiListen != "" {
srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}) srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter})
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
go utils.RunBackgroundHttp("API", srvApi)
} }
if *httpListen != "" { if *httpListen != "" {
srvHttp = servers.NewHttpServer(srvConf) srvHttp = servers.NewHttpServer(srvConf)
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", srvHttp.Addr)
go utils.RunBackgroundHttp("HTTP", srvHttp)
} }
if *httpsListen != "" { if *httpsListen != "" {
srvHttps = servers.NewHttpsServer(srvConf) srvHttps = servers.NewHttpsServer(srvConf)
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", srvHttps.Addr)
go utils.RunBackgroundHttps("HTTPS", srvHttps)
} }
// Wait for exit signal // Wait for exit signal

View File

@ -82,7 +82,7 @@ func (e *ErrorPages) Compile() {
func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error { func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error {
// try to read dir // try to read dir
files, err := fs.ReadDir(e.dir, "") files, err := fs.ReadDir(e.dir, ".")
if err != nil { if err != nil {
return fmt.Errorf("failed to read error pages dir: %w", err) return fmt.Errorf("failed to read error pages dir: %w", err)
} }
@ -101,7 +101,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
ext := filepath.Ext(name) ext := filepath.Ext(name)
// if the extension is not 'html' then ignore the file // if the extension is not 'html' then ignore the file
if ext != "html" { if ext != ".html" {
log.Printf("[ErrorPages] WARNING: ignoring non '.html' file in error pages directory: '%s'\n", name) log.Printf("[ErrorPages] WARNING: ignoring non '.html' file in error pages directory: '%s'\n", name)
continue continue
} }

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"testing/fstest"
) )
func TestErrorPages_ServeError(t *testing.T) { func TestErrorPages_ServeError(t *testing.T) {
@ -29,3 +30,35 @@ func TestErrorPages_ServeError(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "469 Unknown Error Code\n\n", string(a)) assert.Equal(t, "469 Unknown Error Code\n\n", string(a))
} }
func TestErrorPagesWithCustom(t *testing.T) {
fs := fstest.MapFS{
"418.html": {
Data: []byte("418 Custom Error Page\n"),
},
"469.html": {
Data: []byte("469 Custom Error Page\n"),
},
}
errorPages := New(fs)
assert.NoError(t, errorPages.internalCompile(errorPages.m))
rec := httptest.NewRecorder()
errorPages.ServeError(rec, http.StatusTeapot)
res := rec.Result()
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Equal(t, "418 I'm a teapot", res.Status)
a, err := io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "418 Custom Error Page\n", string(a))
rec = httptest.NewRecorder()
errorPages.ServeError(rec, 469)
res = rec.Result()
assert.Equal(t, 469, res.StatusCode)
assert.Equal(t, "469 ", res.Status)
a, err = io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "469 Custom Error Page\n", string(a))
}

View File

@ -5,7 +5,6 @@ import (
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/mrmelon54/mjwt" "github.com/mrmelon54/mjwt"
"log"
"net/http" "net/http"
"time" "time"
) )
@ -19,22 +18,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
// Endpoint for compile action // Endpoint for compile action
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// Get bearer token if !hasPerms(conf.Verify, req, "violet:compile") {
bearer := utils.GetBearer(req)
if bearer == "" {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](conf.Verify, bearer)
if err != nil {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Token must have `violet:compile` perm
if !b.Claims.Perms.Has("violet:compile") {
utils.RespondHttpStatus(rw, http.StatusForbidden) utils.RespondHttpStatus(rw, http.StatusForbidden)
return return
} }
@ -44,8 +28,36 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
rw.WriteHeader(http.StatusAccepted) rw.WriteHeader(http.StatusAccepted)
}) })
// Endpoint for acme-challenge
r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Verify, req, "violet:acme-challenge") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
domain := params.ByName("domain")
if !conf.Domains.IsValid(domain) {
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
return
}
conf.Acme.Put(domain, params.ByName("key"), params.ByName("value"))
rw.WriteHeader(http.StatusAccepted)
})
r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if !hasPerms(conf.Verify, req, "violet:acme-challenge") {
utils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
domain := params.ByName("domain")
if !conf.Domains.IsValid(domain) {
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
return
}
conf.Acme.Delete(domain, params.ByName("key"))
rw.WriteHeader(http.StatusAccepted)
})
// Create and run http server // Create and run http server
s := &http.Server{ return &http.Server{
Addr: conf.ApiListen, Addr: conf.ApiListen,
Handler: r, Handler: r,
ReadTimeout: time.Minute, ReadTimeout: time.Minute,
@ -54,7 +66,21 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
IdleTimeout: time.Minute, IdleTimeout: time.Minute,
MaxHeaderBytes: 2500, MaxHeaderBytes: 2500,
} }
log.Printf("[API] Starting API server on: '%s'\n", s.Addr) }
go utils.RunBackgroundHttp("API", s)
return s func hasPerms(verify mjwt.Provider, req *http.Request, perm string) bool {
// Get bearer token
bearer := utils.GetBearer(req)
if bearer == "" {
return false
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
return false
}
// Token must have perm
return b.Claims.Perms.Has(perm)
} }

164
servers/api_test.go Normal file
View File

@ -0,0 +1,164 @@
package servers
import (
"code.mrmelon54.com/melon/summer-utils/claims"
"code.mrmelon54.com/melon/summer-utils/claims/auth"
"crypto/rand"
"crypto/rsa"
"github.com/MrMelon54/violet/utils"
"github.com/mrmelon54/mjwt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
var snakeOilProv = genSnakeOilProv()
type fakeDomains struct{}
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
func genSnakeOilProv() mjwt.Provider {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
return mjwt.NewMJwtSigner("violet.test", key)
}
func genSnakeOilKey(perm string) string {
p := claims.NewPermStorage()
p.Set(perm)
val, err := snakeOilProv.GenerateJwt("abc", "abc", 5*time.Minute, auth.AccessTokenClaims{
UserId: 1,
Perms: p,
})
if err != nil {
panic(err)
}
return val
}
type fakeCompilable struct{ done bool }
func (f *fakeCompilable) Compile() { f.done = true }
var _ utils.Compilable = &fakeCompilable{}
func TestNewApiServer_Compile(t *testing.T) {
apiConf := &Conf{
Domains: &fakeDomains{},
Acme: utils.NewAcmeChallenge(),
Verify: snakeOilProv,
}
f := &fakeCompilable{}
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
assert.NoError(t, err)
rec := httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusForbidden, res.StatusCode)
assert.False(t, f.done)
req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile"))
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.True(t, f.done)
}
func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
apiConf := &Conf{
Domains: &fakeDomains{},
Acme: utils.NewAcmeChallenge(),
Verify: snakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{})
acmeKey := genSnakeOilKey("violet:acme-challenge")
// Valid domain
req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil)
assert.NoError(t, err)
rec := httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusForbidden, res.StatusCode)
req.Header.Set("Authorization", "Bearer "+acmeKey)
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, "123abc", apiConf.Acme.Get("example.com", "123"))
// Invalid domain
req, err = http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/notexample.com/123/123abc", nil)
assert.NoError(t, err)
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusForbidden, res.StatusCode)
req.Header.Set("Authorization", "Bearer "+acmeKey)
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
assert.Equal(t, "Invalid ACME challenge domain", res.Header.Get("X-Violet-Error"))
}
func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
apiConf := &Conf{
Domains: &fakeDomains{},
Acme: utils.NewAcmeChallenge(),
Verify: snakeOilProv,
}
srv := NewApiServer(apiConf, utils.MultiCompilable{})
acmeKey := genSnakeOilKey("violet:acme-challenge")
// Valid domain
req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil)
assert.NoError(t, err)
rec := httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusForbidden, res.StatusCode)
req.Header.Set("Authorization", "Bearer "+acmeKey)
apiConf.Acme.Put("example.com", "123", "123abc")
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, "", apiConf.Acme.Get("example.com", "123"))
// Invalid domain
req, err = http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/notexample.com/123", nil)
assert.NoError(t, err)
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusForbidden, res.StatusCode)
req.Header.Set("Authorization", "Bearer "+acmeKey)
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
assert.Equal(t, "Invalid ACME challenge domain", res.Header.Get("X-Violet-Error"))
}

View File

@ -1,9 +1,8 @@
package servers package servers
import ( import (
"crypto/tls"
"database/sql" "database/sql"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/domains"
errorPages "github.com/MrMelon54/violet/error-pages" errorPages "github.com/MrMelon54/violet/error-pages"
"github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/favicons"
"github.com/MrMelon54/violet/router" "github.com/MrMelon54/violet/router"
@ -15,11 +14,27 @@ type Conf struct {
ApiListen string // api server listen address ApiListen string // api server listen address
HttpListen string // http server listen address HttpListen string // http server listen address
HttpsListen string // https server listen address HttpsListen string // https server listen address
RateLimit uint64 // rate limit per minute
DB *sql.DB DB *sql.DB
Domains *domains.Domains Domains DomainProvider
Certs *certs.Certs Acme AcmeChallengeProvider
Certs CertProvider
Favicons *favicons.Favicons Favicons *favicons.Favicons
Verify mjwt.Provider Verify mjwt.Provider
ErrorPages *errorPages.ErrorPages ErrorPages *errorPages.ErrorPages
Router *router.Manager Router *router.Manager
} }
type DomainProvider interface {
IsValid(host string) bool
}
type AcmeChallengeProvider interface {
Get(domain, key string) string
Put(domain, key, value string)
Delete(domain, key string)
}
type CertProvider interface {
GetCertForDomain(domain string) *tls.Certificate
}

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/MrMelon54/violet/utils" "github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"log"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
@ -27,47 +26,30 @@ func NewHttpServer(conf *Conf) *http.Server {
} }
// Endpoint for acme challenge outputs // Endpoint for acme challenge outputs
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { r.GET("/.well-known/acme-challenge/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok { h := utils.GetDomainWithoutPort(req.Host)
// check if the host is valid // check if the host is valid
if !conf.Domains.IsValid(req.Host) { if !conf.Domains.IsValid(req.Host) {
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid host")
return return
} }
// check if the key is valid // check if the key is valid
key := params.ByName("key") value := conf.Acme.Get(h, params.ByName("key"))
if key == "" { if value == "" {
rw.WriteHeader(http.StatusNotFound) rw.WriteHeader(http.StatusNotFound)
return return
} }
// prepare for executing query
prepare, err := conf.DB.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?")
if err != nil {
utils.RespondHttpStatus(rw, http.StatusInternalServerError)
return
}
// query the row and extract the value
row := prepare.QueryRow(h, key)
var value string
err = row.Scan(&value)
if err != nil {
utils.RespondHttpStatus(rw, http.StatusInternalServerError)
return
}
// output response // output response
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte(value)) _, _ = rw.Write([]byte(value))
}
rw.WriteHeader(http.StatusNotFound)
}) })
// All other paths lead here and are forwarded to HTTPS // All other paths lead here and are forwarded to HTTPS
r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if h, ok := utils.GetDomainWithoutPort(req.Host); ok { h := utils.GetDomainWithoutPort(req.Host)
u := &url.URL{ u := &url.URL{
Scheme: "https", Scheme: "https",
Host: h + secureExtend, Host: h + secureExtend,
@ -76,11 +58,10 @@ func NewHttpServer(conf *Conf) *http.Server {
RawQuery: req.URL.RawQuery, RawQuery: req.URL.RawQuery,
} }
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect) utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
}
}) })
// Create and run http server // Create and run http server
s := &http.Server{ return &http.Server{
Addr: conf.HttpListen, Addr: conf.HttpListen,
Handler: r, Handler: r,
ReadTimeout: time.Minute, ReadTimeout: time.Minute,
@ -89,7 +70,4 @@ func NewHttpServer(conf *Conf) *http.Server {
IdleTimeout: time.Minute, IdleTimeout: time.Minute,
MaxHeaderBytes: 2500, MaxHeaderBytes: 2500,
} }
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttp("HTTP", s)
return s
} }

46
servers/http_test.go Normal file
View File

@ -0,0 +1,46 @@
package servers
import (
"bytes"
"github.com/MrMelon54/violet/utils"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestNewHttpServer_AcmeChallenge(t *testing.T) {
httpConf := &Conf{
Domains: &fakeDomains{},
Acme: utils.NewAcmeChallenge(),
Verify: snakeOilProv,
}
srv := NewHttpServer(httpConf)
httpConf.Acme.Put("example.com", "456", "456def")
req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil)
assert.NoError(t, err)
rec := httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusOK, res.StatusCode)
all, err := io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, 0, bytes.Compare([]byte("456def"), all))
// Invalid key
req, err = http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/789", nil)
assert.NoError(t, err)
rec = httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res = rec.Result()
assert.Equal(t, http.StatusNotFound, res.StatusCode)
all, err = io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, 0, bytes.Compare([]byte(""), all))
}

View File

@ -16,10 +16,9 @@ import (
// NewHttpsServer creates and runs a http server containing the public https // NewHttpsServer creates and runs a http server containing the public https
// endpoints for the reverse proxy. // endpoints for the reverse proxy.
func NewHttpsServer(conf *Conf) *http.Server { func NewHttpsServer(conf *Conf) *http.Server {
s := &http.Server{ return &http.Server{
Addr: conf.HttpsListen, Addr: conf.HttpsListen,
Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)), Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),
DisableGeneralOptionsHandler: false,
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// error out on invalid domains // error out on invalid domains
if !conf.Domains.IsValid(info.ServerName) { if !conf.Domains.IsValid(info.ServerName) {
@ -41,12 +40,9 @@ func NewHttpsServer(conf *Conf) *http.Server {
IdleTimeout: 150 * time.Second, IdleTimeout: 150 * time.Second,
MaxHeaderBytes: 4096000, MaxHeaderBytes: 4096000,
ConnState: func(conn net.Conn, state http.ConnState) { ConnState: func(conn net.Conn, state http.ConnState) {
fmt.Printf("%s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String()) fmt.Printf("[HTTPS] %s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String())
}, },
} }
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", s.Addr)
go utils.RunBackgroundHttps("HTTPS", s)
return s
} }
// setupRateLimiter is an internal function to create a middleware to manage // setupRateLimiter is an internal function to create a middleware to manage

59
servers/https_test.go Normal file
View File

@ -0,0 +1,59 @@
package servers
import (
"database/sql"
"github.com/MrMelon54/violet/certs"
"github.com/MrMelon54/violet/proxy"
"github.com/MrMelon54/violet/router"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type fakeTransport struct{}
func (f *fakeTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
return rec.Result(), nil
}
func TestNewHttpsServer_RateLimit(t *testing.T) {
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
assert.NoError(t, err)
ft := &fakeTransport{}
httpsConf := &Conf{
RateLimit: 5,
Domains: &fakeDomains{},
Certs: certs.New(nil, nil, true),
Verify: snakeOilProv,
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
}
srv := NewHttpsServer(httpsConf)
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
req.RemoteAddr = "127.0.0.1:1447"
assert.NoError(t, err)
wg := &sync.WaitGroup{}
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
defer wg.Done()
rec := httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusOK, res.StatusCode)
}()
}
wg.Wait()
rec := httptest.NewRecorder()
srv.Handler.ServeHTTP(rec, req)
res := rec.Result()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}

55
utils/acme-challenges.go Normal file
View File

@ -0,0 +1,55 @@
package utils
import "sync"
type AcmeChallenges struct {
s *sync.RWMutex
d map[string]*AcmeStorage
}
type AcmeStorage struct {
s *sync.RWMutex
v map[string]string
}
func NewAcmeChallenge() *AcmeChallenges {
return &AcmeChallenges{
s: &sync.RWMutex{},
d: make(map[string]*AcmeStorage),
}
}
func (a *AcmeChallenges) Get(domain, key string) string {
a.s.RLock()
defer a.s.RUnlock()
if m := a.d[domain]; m != nil {
m.s.RLock()
defer m.s.RUnlock()
return m.v[key]
}
return ""
}
func (a *AcmeChallenges) Put(domain, key, value string) {
a.s.Lock()
m := a.d[domain]
if m == nil {
m = &AcmeStorage{
s: &sync.RWMutex{},
v: make(map[string]string),
}
a.d[domain] = m
}
m.s.Lock()
m.v[key] = value
m.s.Unlock()
a.s.Unlock()
}
func (a *AcmeChallenges) Delete(domain, key string) {
a.s.Lock()
if m := a.d[domain]; m != nil {
delete(m.v, key)
}
a.s.Unlock()
}

View File

@ -0,0 +1,27 @@
package utils
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestAcmeChallenges(t *testing.T) {
a := NewAcmeChallenge()
assert.Equal(t, "", a.Get("example.com", "123"))
// The challenge should be created
a.Put("example.com", "123", "123abc")
assert.Equal(t, "123abc", a.Get("example.com", "123"))
// The challenge should be deleted
a.Delete("example.com", "123")
assert.Equal(t, "", a.Get("example.com", "123"))
// This should not crash or stop execution
a.Delete("example.com", "123")
assert.Equal(t, "", a.Get("example.com", "123"))
// This should not crash or stop execution
a.Delete("www.example.com", "123")
assert.Equal(t, "", a.Get("example.com", "123"))
}

View File

@ -27,13 +27,13 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok
// without the port. // without the port.
// //
// example.com:443 => example.com // example.com:443 => example.com
func GetDomainWithoutPort(domain string) (string, bool) { func GetDomainWithoutPort(domain string) string {
// if a valid index isn't found then return false // if a valid index isn't found then return false
n := strings.LastIndexByte(domain, ':') n := strings.LastIndexByte(domain, ':')
if n == -1 { if n == -1 {
return "", false return domain
} }
return domain[:n], true return domain[:n]
} }
// ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced // ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced

View File

@ -18,12 +18,16 @@ func TestSplitDomainPort(t *testing.T) {
} }
func TestDomainWithoutPort(t *testing.T) { func TestDomainWithoutPort(t *testing.T) {
domain, ok := GetDomainWithoutPort("www.example.com:5612") domain := GetDomainWithoutPort("www.example.com:5612")
assert.True(t, ok, "Output should be true")
assert.Equal(t, "www.example.com", domain) assert.Equal(t, "www.example.com", domain)
domain, ok = GetDomainWithoutPort("example.com:443") domain = GetDomainWithoutPort("example.com:443")
assert.True(t, ok, "Output should be true") assert.Equal(t, "example.com", domain)
domain = GetDomainWithoutPort("www.example.com")
assert.Equal(t, "www.example.com", domain)
domain = GetDomainWithoutPort("example.com")
assert.Equal(t, "example.com", domain) assert.Equal(t, "example.com", domain)
} }

View File

@ -0,0 +1,22 @@
package utils
import (
"github.com/stretchr/testify/assert"
"testing"
)
type fakeCompile struct{ done bool }
func (f *fakeCompile) Compile() {
f.done = true
}
var _ Compilable = &fakeCompile{}
func TestMultiCompilable_Compile(t *testing.T) {
f := &fakeCompile{}
a := MultiCompilable{f}
assert.False(t, f.done)
a.Compile()
assert.True(t, f.done)
}

32
utils/response_test.go Normal file
View File

@ -0,0 +1,32 @@
package utils
import (
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestRespondHttpStatus(t *testing.T) {
rec := httptest.NewRecorder()
RespondHttpStatus(rec, http.StatusTeapot)
res := rec.Result()
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Equal(t, "418 I'm a teapot", res.Status)
a, err := io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
}
func TestRespondVioletError(t *testing.T) {
rec := httptest.NewRecorder()
RespondVioletError(rec, http.StatusTeapot, "Hidden Error Message")
res := rec.Result()
assert.Equal(t, http.StatusTeapot, res.StatusCode)
assert.Equal(t, "418 I'm a teapot", res.Status)
a, err := io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
assert.Equal(t, "Hidden Error Message", res.Header.Get("X-Violet-Error"))
}

View File

@ -0,0 +1,20 @@
package utils
import (
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestGetBearer(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer abc")
assert.Equal(t, "abc", GetBearer(req))
}
func TestGetBearer_Empty(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
assert.NoError(t, err)
assert.Equal(t, "", GetBearer(req))
}