mirror of
https://github.com/1f349/violet.git
synced 2024-11-21 19:01:39 +00:00
Write lots of tests
This commit is contained in:
parent
1f487eb80c
commit
afc661c62b
@ -121,6 +121,10 @@ func (c *Certs) Compile() {
|
||||
// internalCompile is a hidden internal method for loading the certificate and
|
||||
// key files
|
||||
func (c *Certs) internalCompile(m map[string]*tls.Certificate) error {
|
||||
if c.cDir == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// try to read dir
|
||||
files, err := fs.ReadDir(c.cDir, "")
|
||||
if err != nil {
|
||||
|
@ -33,6 +33,7 @@ var (
|
||||
httpListen = flag.String("http", "0.0.0.0:80", "address for http listening")
|
||||
httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening")
|
||||
inkscapeCmd = flag.String("inkscape", "inkscape", "Path to inkscape binary")
|
||||
rateLimit = flag.Uint64("ratelimit", 300, "Rate limit (max requests per minute)")
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -61,6 +62,7 @@ func main() {
|
||||
}
|
||||
|
||||
allowedDomains := domains.New(db) // load allowed domains
|
||||
acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store
|
||||
allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager
|
||||
reverseProxy := proxy.NewHybridTransport() // load reverse proxy
|
||||
dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider
|
||||
@ -72,8 +74,10 @@ func main() {
|
||||
ApiListen: *apiListen,
|
||||
HttpListen: *httpListen,
|
||||
HttpsListen: *httpsListen,
|
||||
RateLimit: *rateLimit,
|
||||
DB: db,
|
||||
Domains: allowedDomains,
|
||||
Acme: acmeChallenges,
|
||||
Certs: allowedCerts,
|
||||
Favicons: dynamicFavicons,
|
||||
Verify: nil, // TODO: add mjwt verify support
|
||||
@ -84,12 +88,18 @@ func main() {
|
||||
var srvApi, srvHttp, srvHttps *http.Server
|
||||
if *apiListen != "" {
|
||||
srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter})
|
||||
log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr)
|
||||
go utils.RunBackgroundHttp("API", srvApi)
|
||||
}
|
||||
if *httpListen != "" {
|
||||
srvHttp = servers.NewHttpServer(srvConf)
|
||||
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", srvHttp.Addr)
|
||||
go utils.RunBackgroundHttp("HTTP", srvHttp)
|
||||
}
|
||||
if *httpsListen != "" {
|
||||
srvHttps = servers.NewHttpsServer(srvConf)
|
||||
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", srvHttps.Addr)
|
||||
go utils.RunBackgroundHttps("HTTPS", srvHttps)
|
||||
}
|
||||
|
||||
// Wait for exit signal
|
||||
|
@ -82,7 +82,7 @@ func (e *ErrorPages) Compile() {
|
||||
|
||||
func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error {
|
||||
// try to read dir
|
||||
files, err := fs.ReadDir(e.dir, "")
|
||||
files, err := fs.ReadDir(e.dir, ".")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read error pages dir: %w", err)
|
||||
}
|
||||
@ -101,7 +101,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err
|
||||
ext := filepath.Ext(name)
|
||||
|
||||
// if the extension is not 'html' then ignore the file
|
||||
if ext != "html" {
|
||||
if ext != ".html" {
|
||||
log.Printf("[ErrorPages] WARNING: ignoring non '.html' file in error pages directory: '%s'\n", name)
|
||||
continue
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
)
|
||||
|
||||
func TestErrorPages_ServeError(t *testing.T) {
|
||||
@ -29,3 +30,35 @@ func TestErrorPages_ServeError(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "469 Unknown Error Code\n\n", string(a))
|
||||
}
|
||||
|
||||
func TestErrorPagesWithCustom(t *testing.T) {
|
||||
fs := fstest.MapFS{
|
||||
"418.html": {
|
||||
Data: []byte("418 Custom Error Page\n"),
|
||||
},
|
||||
"469.html": {
|
||||
Data: []byte("469 Custom Error Page\n"),
|
||||
},
|
||||
}
|
||||
|
||||
errorPages := New(fs)
|
||||
assert.NoError(t, errorPages.internalCompile(errorPages.m))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
errorPages.ServeError(rec, http.StatusTeapot)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||
assert.Equal(t, "418 I'm a teapot", res.Status)
|
||||
a, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "418 Custom Error Page\n", string(a))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
errorPages.ServeError(rec, 469)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, 469, res.StatusCode)
|
||||
assert.Equal(t, "469 ", res.Status)
|
||||
a, err = io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "469 Custom Error Page\n", string(a))
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/mrmelon54/mjwt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@ -19,22 +18,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
|
||||
|
||||
// Endpoint for compile action
|
||||
r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||
// Get bearer token
|
||||
bearer := utils.GetBearer(req)
|
||||
if bearer == "" {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Read claims from mjwt
|
||||
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](conf.Verify, bearer)
|
||||
if err != nil {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Token must have `violet:compile` perm
|
||||
if !b.Claims.Perms.Has("violet:compile") {
|
||||
if !hasPerms(conf.Verify, req, "violet:compile") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@ -44,8 +28,36 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
// Endpoint for acme-challenge
|
||||
r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if !hasPerms(conf.Verify, req, "violet:acme-challenge") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
domain := params.ByName("domain")
|
||||
if !conf.Domains.IsValid(domain) {
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
||||
return
|
||||
}
|
||||
conf.Acme.Put(domain, params.ByName("key"), params.ByName("value"))
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if !hasPerms(conf.Verify, req, "violet:acme-challenge") {
|
||||
utils.RespondHttpStatus(rw, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
domain := params.ByName("domain")
|
||||
if !conf.Domains.IsValid(domain) {
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain")
|
||||
return
|
||||
}
|
||||
conf.Acme.Delete(domain, params.ByName("key"))
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
// Create and run http server
|
||||
s := &http.Server{
|
||||
return &http.Server{
|
||||
Addr: conf.ApiListen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
@ -54,7 +66,21 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server
|
||||
IdleTimeout: time.Minute,
|
||||
MaxHeaderBytes: 2500,
|
||||
}
|
||||
log.Printf("[API] Starting API server on: '%s'\n", s.Addr)
|
||||
go utils.RunBackgroundHttp("API", s)
|
||||
return s
|
||||
}
|
||||
|
||||
func hasPerms(verify mjwt.Provider, req *http.Request, perm string) bool {
|
||||
// Get bearer token
|
||||
bearer := utils.GetBearer(req)
|
||||
if bearer == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Read claims from mjwt
|
||||
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Token must have perm
|
||||
return b.Claims.Perms.Has(perm)
|
||||
}
|
||||
|
164
servers/api_test.go
Normal file
164
servers/api_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"code.mrmelon54.com/melon/summer-utils/claims"
|
||||
"code.mrmelon54.com/melon/summer-utils/claims/auth"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/mrmelon54/mjwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var snakeOilProv = genSnakeOilProv()
|
||||
|
||||
type fakeDomains struct{}
|
||||
|
||||
func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" }
|
||||
|
||||
func genSnakeOilProv() mjwt.Provider {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return mjwt.NewMJwtSigner("violet.test", key)
|
||||
}
|
||||
|
||||
func genSnakeOilKey(perm string) string {
|
||||
p := claims.NewPermStorage()
|
||||
p.Set(perm)
|
||||
val, err := snakeOilProv.GenerateJwt("abc", "abc", 5*time.Minute, auth.AccessTokenClaims{
|
||||
UserId: 1,
|
||||
Perms: p,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
type fakeCompilable struct{ done bool }
|
||||
|
||||
func (f *fakeCompilable) Compile() { f.done = true }
|
||||
|
||||
var _ utils.Compilable = &fakeCompilable{}
|
||||
|
||||
func TestNewApiServer_Compile(t *testing.T) {
|
||||
apiConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Verify: snakeOilProv,
|
||||
}
|
||||
f := &fakeCompilable{}
|
||||
srv := NewApiServer(apiConf, utils.MultiCompilable{f})
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
assert.False(t, f.done)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile"))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.True(t, f.done)
|
||||
}
|
||||
|
||||
func TestNewApiServer_AcmeChallenge_Put(t *testing.T) {
|
||||
apiConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Verify: snakeOilProv,
|
||||
}
|
||||
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
||||
acmeKey := genSnakeOilKey("violet:acme-challenge")
|
||||
|
||||
// Valid domain
|
||||
req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+acmeKey)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.Equal(t, "123abc", apiConf.Acme.Get("example.com", "123"))
|
||||
|
||||
// Invalid domain
|
||||
req, err = http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/notexample.com/123/123abc", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+acmeKey)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
assert.Equal(t, "Invalid ACME challenge domain", res.Header.Get("X-Violet-Error"))
|
||||
}
|
||||
|
||||
func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) {
|
||||
apiConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Verify: snakeOilProv,
|
||||
}
|
||||
srv := NewApiServer(apiConf, utils.MultiCompilable{})
|
||||
acmeKey := genSnakeOilKey("violet:acme-challenge")
|
||||
|
||||
// Valid domain
|
||||
req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+acmeKey)
|
||||
apiConf.Acme.Put("example.com", "123", "123abc")
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.Equal(t, "", apiConf.Acme.Get("example.com", "123"))
|
||||
|
||||
// Invalid domain
|
||||
req, err = http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/notexample.com/123", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+acmeKey)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
assert.Equal(t, "Invalid ACME challenge domain", res.Header.Get("X-Violet-Error"))
|
||||
}
|
@ -1,9 +1,8 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"github.com/MrMelon54/violet/certs"
|
||||
"github.com/MrMelon54/violet/domains"
|
||||
errorPages "github.com/MrMelon54/violet/error-pages"
|
||||
"github.com/MrMelon54/violet/favicons"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
@ -15,11 +14,27 @@ type Conf struct {
|
||||
ApiListen string // api server listen address
|
||||
HttpListen string // http server listen address
|
||||
HttpsListen string // https server listen address
|
||||
RateLimit uint64 // rate limit per minute
|
||||
DB *sql.DB
|
||||
Domains *domains.Domains
|
||||
Certs *certs.Certs
|
||||
Domains DomainProvider
|
||||
Acme AcmeChallengeProvider
|
||||
Certs CertProvider
|
||||
Favicons *favicons.Favicons
|
||||
Verify mjwt.Provider
|
||||
ErrorPages *errorPages.ErrorPages
|
||||
Router *router.Manager
|
||||
}
|
||||
|
||||
type DomainProvider interface {
|
||||
IsValid(host string) bool
|
||||
}
|
||||
|
||||
type AcmeChallengeProvider interface {
|
||||
Get(domain, key string) string
|
||||
Put(domain, key, value string)
|
||||
Delete(domain, key string)
|
||||
}
|
||||
|
||||
type CertProvider interface {
|
||||
GetCertForDomain(domain string) *tls.Certificate
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@ -27,47 +26,30 @@ func NewHttpServer(conf *Conf) *http.Server {
|
||||
}
|
||||
|
||||
// Endpoint for acme challenge outputs
|
||||
r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||
r.GET("/.well-known/acme-challenge/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
h := utils.GetDomainWithoutPort(req.Host)
|
||||
|
||||
// check if the host is valid
|
||||
if !conf.Domains.IsValid(req.Host) {
|
||||
http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420)
|
||||
utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid host")
|
||||
return
|
||||
}
|
||||
|
||||
// check if the key is valid
|
||||
key := params.ByName("key")
|
||||
if key == "" {
|
||||
value := conf.Acme.Get(h, params.ByName("key"))
|
||||
if value == "" {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// prepare for executing query
|
||||
prepare, err := conf.DB.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?")
|
||||
if err != nil {
|
||||
utils.RespondHttpStatus(rw, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// query the row and extract the value
|
||||
row := prepare.QueryRow(h, key)
|
||||
var value string
|
||||
err = row.Scan(&value)
|
||||
if err != nil {
|
||||
utils.RespondHttpStatus(rw, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// output response
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write([]byte(value))
|
||||
}
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
// All other paths lead here and are forwarded to HTTPS
|
||||
r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if h, ok := utils.GetDomainWithoutPort(req.Host); ok {
|
||||
h := utils.GetDomainWithoutPort(req.Host)
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: h + secureExtend,
|
||||
@ -76,11 +58,10 @@ func NewHttpServer(conf *Conf) *http.Server {
|
||||
RawQuery: req.URL.RawQuery,
|
||||
}
|
||||
utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect)
|
||||
}
|
||||
})
|
||||
|
||||
// Create and run http server
|
||||
s := &http.Server{
|
||||
return &http.Server{
|
||||
Addr: conf.HttpListen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
@ -89,7 +70,4 @@ func NewHttpServer(conf *Conf) *http.Server {
|
||||
IdleTimeout: time.Minute,
|
||||
MaxHeaderBytes: 2500,
|
||||
}
|
||||
log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr)
|
||||
go utils.RunBackgroundHttp("HTTP", s)
|
||||
return s
|
||||
}
|
||||
|
46
servers/http_test.go
Normal file
46
servers/http_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/MrMelon54/violet/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewHttpServer_AcmeChallenge(t *testing.T) {
|
||||
httpConf := &Conf{
|
||||
Domains: &fakeDomains{},
|
||||
Acme: utils.NewAcmeChallenge(),
|
||||
Verify: snakeOilProv,
|
||||
}
|
||||
srv := NewHttpServer(httpConf)
|
||||
httpConf.Acme.Put("example.com", "456", "456def")
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
all, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, bytes.Compare([]byte("456def"), all))
|
||||
|
||||
// Invalid key
|
||||
req, err = http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/789", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res = rec.Result()
|
||||
assert.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
|
||||
all, err = io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, bytes.Compare([]byte(""), all))
|
||||
}
|
@ -16,10 +16,9 @@ import (
|
||||
// NewHttpsServer creates and runs a http server containing the public https
|
||||
// endpoints for the reverse proxy.
|
||||
func NewHttpsServer(conf *Conf) *http.Server {
|
||||
s := &http.Server{
|
||||
return &http.Server{
|
||||
Addr: conf.HttpsListen,
|
||||
Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
||||
DisableGeneralOptionsHandler: false,
|
||||
Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)),
|
||||
TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// error out on invalid domains
|
||||
if !conf.Domains.IsValid(info.ServerName) {
|
||||
@ -41,12 +40,9 @@ func NewHttpsServer(conf *Conf) *http.Server {
|
||||
IdleTimeout: 150 * time.Second,
|
||||
MaxHeaderBytes: 4096000,
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
fmt.Printf("%s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String())
|
||||
fmt.Printf("[HTTPS] %s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String())
|
||||
},
|
||||
}
|
||||
log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", s.Addr)
|
||||
go utils.RunBackgroundHttps("HTTPS", s)
|
||||
return s
|
||||
}
|
||||
|
||||
// setupRateLimiter is an internal function to create a middleware to manage
|
||||
|
59
servers/https_test.go
Normal file
59
servers/https_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package servers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/MrMelon54/violet/certs"
|
||||
"github.com/MrMelon54/violet/proxy"
|
||||
"github.com/MrMelon54/violet/router"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeTransport struct{}
|
||||
|
||||
func (f *fakeTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
rec := httptest.NewRecorder()
|
||||
rec.WriteHeader(http.StatusOK)
|
||||
return rec.Result(), nil
|
||||
}
|
||||
|
||||
func TestNewHttpsServer_RateLimit(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
|
||||
assert.NoError(t, err)
|
||||
|
||||
ft := &fakeTransport{}
|
||||
httpsConf := &Conf{
|
||||
RateLimit: 5,
|
||||
Domains: &fakeDomains{},
|
||||
Certs: certs.New(nil, nil, true),
|
||||
Verify: snakeOilProv,
|
||||
Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)),
|
||||
}
|
||||
srv := NewHttpsServer(httpsConf)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1447"
|
||||
assert.NoError(t, err)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler.ServeHTTP(rec, req)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
55
utils/acme-challenges.go
Normal file
55
utils/acme-challenges.go
Normal file
@ -0,0 +1,55 @@
|
||||
package utils
|
||||
|
||||
import "sync"
|
||||
|
||||
type AcmeChallenges struct {
|
||||
s *sync.RWMutex
|
||||
d map[string]*AcmeStorage
|
||||
}
|
||||
|
||||
type AcmeStorage struct {
|
||||
s *sync.RWMutex
|
||||
v map[string]string
|
||||
}
|
||||
|
||||
func NewAcmeChallenge() *AcmeChallenges {
|
||||
return &AcmeChallenges{
|
||||
s: &sync.RWMutex{},
|
||||
d: make(map[string]*AcmeStorage),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AcmeChallenges) Get(domain, key string) string {
|
||||
a.s.RLock()
|
||||
defer a.s.RUnlock()
|
||||
if m := a.d[domain]; m != nil {
|
||||
m.s.RLock()
|
||||
defer m.s.RUnlock()
|
||||
return m.v[key]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AcmeChallenges) Put(domain, key, value string) {
|
||||
a.s.Lock()
|
||||
m := a.d[domain]
|
||||
if m == nil {
|
||||
m = &AcmeStorage{
|
||||
s: &sync.RWMutex{},
|
||||
v: make(map[string]string),
|
||||
}
|
||||
a.d[domain] = m
|
||||
}
|
||||
m.s.Lock()
|
||||
m.v[key] = value
|
||||
m.s.Unlock()
|
||||
a.s.Unlock()
|
||||
}
|
||||
|
||||
func (a *AcmeChallenges) Delete(domain, key string) {
|
||||
a.s.Lock()
|
||||
if m := a.d[domain]; m != nil {
|
||||
delete(m.v, key)
|
||||
}
|
||||
a.s.Unlock()
|
||||
}
|
27
utils/acme-challenges_test.go
Normal file
27
utils/acme-challenges_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAcmeChallenges(t *testing.T) {
|
||||
a := NewAcmeChallenge()
|
||||
assert.Equal(t, "", a.Get("example.com", "123"))
|
||||
|
||||
// The challenge should be created
|
||||
a.Put("example.com", "123", "123abc")
|
||||
assert.Equal(t, "123abc", a.Get("example.com", "123"))
|
||||
|
||||
// The challenge should be deleted
|
||||
a.Delete("example.com", "123")
|
||||
assert.Equal(t, "", a.Get("example.com", "123"))
|
||||
|
||||
// This should not crash or stop execution
|
||||
a.Delete("example.com", "123")
|
||||
assert.Equal(t, "", a.Get("example.com", "123"))
|
||||
|
||||
// This should not crash or stop execution
|
||||
a.Delete("www.example.com", "123")
|
||||
assert.Equal(t, "", a.Get("example.com", "123"))
|
||||
}
|
@ -27,13 +27,13 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok
|
||||
// without the port.
|
||||
//
|
||||
// example.com:443 => example.com
|
||||
func GetDomainWithoutPort(domain string) (string, bool) {
|
||||
func GetDomainWithoutPort(domain string) string {
|
||||
// if a valid index isn't found then return false
|
||||
n := strings.LastIndexByte(domain, ':')
|
||||
if n == -1 {
|
||||
return "", false
|
||||
return domain
|
||||
}
|
||||
return domain[:n], true
|
||||
return domain[:n]
|
||||
}
|
||||
|
||||
// ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced
|
||||
|
@ -18,12 +18,16 @@ func TestSplitDomainPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDomainWithoutPort(t *testing.T) {
|
||||
domain, ok := GetDomainWithoutPort("www.example.com:5612")
|
||||
assert.True(t, ok, "Output should be true")
|
||||
domain := GetDomainWithoutPort("www.example.com:5612")
|
||||
assert.Equal(t, "www.example.com", domain)
|
||||
|
||||
domain, ok = GetDomainWithoutPort("example.com:443")
|
||||
assert.True(t, ok, "Output should be true")
|
||||
domain = GetDomainWithoutPort("example.com:443")
|
||||
assert.Equal(t, "example.com", domain)
|
||||
|
||||
domain = GetDomainWithoutPort("www.example.com")
|
||||
assert.Equal(t, "www.example.com", domain)
|
||||
|
||||
domain = GetDomainWithoutPort("example.com")
|
||||
assert.Equal(t, "example.com", domain)
|
||||
}
|
||||
|
||||
|
22
utils/multi-compilable_test.go
Normal file
22
utils/multi-compilable_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeCompile struct{ done bool }
|
||||
|
||||
func (f *fakeCompile) Compile() {
|
||||
f.done = true
|
||||
}
|
||||
|
||||
var _ Compilable = &fakeCompile{}
|
||||
|
||||
func TestMultiCompilable_Compile(t *testing.T) {
|
||||
f := &fakeCompile{}
|
||||
a := MultiCompilable{f}
|
||||
assert.False(t, f.done)
|
||||
a.Compile()
|
||||
assert.True(t, f.done)
|
||||
}
|
32
utils/response_test.go
Normal file
32
utils/response_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRespondHttpStatus(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
RespondHttpStatus(rec, http.StatusTeapot)
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||
assert.Equal(t, "418 I'm a teapot", res.Status)
|
||||
a, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
|
||||
}
|
||||
|
||||
func TestRespondVioletError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
RespondVioletError(rec, http.StatusTeapot, "Hidden Error Message")
|
||||
res := rec.Result()
|
||||
assert.Equal(t, http.StatusTeapot, res.StatusCode)
|
||||
assert.Equal(t, "418 I'm a teapot", res.Status)
|
||||
a, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "418 I'm a teapot\n\n", string(a))
|
||||
assert.Equal(t, "Hidden Error Message", res.Header.Get("X-Violet-Error"))
|
||||
}
|
20
utils/server-utils_test.go
Normal file
20
utils/server-utils_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetBearer(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer abc")
|
||||
assert.Equal(t, "abc", GetBearer(req))
|
||||
}
|
||||
|
||||
func TestGetBearer_Empty(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", GetBearer(req))
|
||||
}
|
Loading…
Reference in New Issue
Block a user