diff --git a/certs/certs.go b/certs/certs.go index 66a0f6d..14252ae 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -121,6 +121,10 @@ func (c *Certs) Compile() { // internalCompile is a hidden internal method for loading the certificate and // key files func (c *Certs) internalCompile(m map[string]*tls.Certificate) error { + if c.cDir == nil { + return nil + } + // try to read dir files, err := fs.ReadDir(c.cDir, "") if err != nil { diff --git a/cmd/violet/main.go b/cmd/violet/main.go index 69dc0f8..97ea2e8 100644 --- a/cmd/violet/main.go +++ b/cmd/violet/main.go @@ -33,6 +33,7 @@ var ( httpListen = flag.String("http", "0.0.0.0:80", "address for http listening") httpsListen = flag.String("https", "0.0.0.0:443", "address for https listening") inkscapeCmd = flag.String("inkscape", "inkscape", "Path to inkscape binary") + rateLimit = flag.Uint64("ratelimit", 300, "Rate limit (max requests per minute)") ) func main() { @@ -61,6 +62,7 @@ func main() { } allowedDomains := domains.New(db) // load allowed domains + acmeChallenges := utils.NewAcmeChallenge() // load acme challenge store allowedCerts := certs.New(os.DirFS(*certPath), os.DirFS(*keyPath), *selfSigned) // load certificate manager reverseProxy := proxy.NewHybridTransport() // load reverse proxy dynamicFavicons := favicons.New(db, *inkscapeCmd) // load dynamic favicon provider @@ -72,8 +74,10 @@ func main() { ApiListen: *apiListen, HttpListen: *httpListen, HttpsListen: *httpsListen, + RateLimit: *rateLimit, DB: db, Domains: allowedDomains, + Acme: acmeChallenges, Certs: allowedCerts, Favicons: dynamicFavicons, Verify: nil, // TODO: add mjwt verify support @@ -84,12 +88,18 @@ func main() { var srvApi, srvHttp, srvHttps *http.Server if *apiListen != "" { srvApi = servers.NewApiServer(srvConf, utils.MultiCompilable{allowedDomains, allowedCerts, dynamicFavicons, dynamicErrorPages, dynamicRouter}) + log.Printf("[API] Starting API server on: '%s'\n", srvApi.Addr) + go utils.RunBackgroundHttp("API", srvApi) } if *httpListen != "" { srvHttp = servers.NewHttpServer(srvConf) + log.Printf("[HTTP] Starting HTTP server on: '%s'\n", srvHttp.Addr) + go utils.RunBackgroundHttp("HTTP", srvHttp) } if *httpsListen != "" { srvHttps = servers.NewHttpsServer(srvConf) + log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", srvHttps.Addr) + go utils.RunBackgroundHttps("HTTPS", srvHttps) } // Wait for exit signal diff --git a/error-pages/error-pages.go b/error-pages/error-pages.go index 7eb91dc..38f316d 100644 --- a/error-pages/error-pages.go +++ b/error-pages/error-pages.go @@ -82,7 +82,7 @@ func (e *ErrorPages) Compile() { func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) error { // try to read dir - files, err := fs.ReadDir(e.dir, "") + files, err := fs.ReadDir(e.dir, ".") if err != nil { return fmt.Errorf("failed to read error pages dir: %w", err) } @@ -101,7 +101,7 @@ func (e *ErrorPages) internalCompile(m map[int]func(rw http.ResponseWriter)) err ext := filepath.Ext(name) // if the extension is not 'html' then ignore the file - if ext != "html" { + if ext != ".html" { log.Printf("[ErrorPages] WARNING: ignoring non '.html' file in error pages directory: '%s'\n", name) continue } diff --git a/error-pages/error-pages_test.go b/error-pages/error-pages_test.go index d5ee363..5f92178 100644 --- a/error-pages/error-pages_test.go +++ b/error-pages/error-pages_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "testing/fstest" ) func TestErrorPages_ServeError(t *testing.T) { @@ -29,3 +30,35 @@ func TestErrorPages_ServeError(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "469 Unknown Error Code\n\n", string(a)) } + +func TestErrorPagesWithCustom(t *testing.T) { + fs := fstest.MapFS{ + "418.html": { + Data: []byte("418 Custom Error Page\n"), + }, + "469.html": { + Data: []byte("469 Custom Error Page\n"), + }, + } + + errorPages := New(fs) + assert.NoError(t, errorPages.internalCompile(errorPages.m)) + + rec := httptest.NewRecorder() + errorPages.ServeError(rec, http.StatusTeapot) + res := rec.Result() + assert.Equal(t, http.StatusTeapot, res.StatusCode) + assert.Equal(t, "418 I'm a teapot", res.Status) + a, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "418 Custom Error Page\n", string(a)) + + rec = httptest.NewRecorder() + errorPages.ServeError(rec, 469) + res = rec.Result() + assert.Equal(t, 469, res.StatusCode) + assert.Equal(t, "469 ", res.Status) + a, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "469 Custom Error Page\n", string(a)) +} diff --git a/servers/api.go b/servers/api.go index 292ebc2..e7cd1a7 100644 --- a/servers/api.go +++ b/servers/api.go @@ -5,7 +5,6 @@ import ( "github.com/MrMelon54/violet/utils" "github.com/julienschmidt/httprouter" "github.com/mrmelon54/mjwt" - "log" "net/http" "time" ) @@ -19,22 +18,7 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server // Endpoint for compile action r.POST("/compile", func(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { - // Get bearer token - bearer := utils.GetBearer(req) - if bearer == "" { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - - // Read claims from mjwt - _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](conf.Verify, bearer) - if err != nil { - utils.RespondHttpStatus(rw, http.StatusForbidden) - return - } - - // Token must have `violet:compile` perm - if !b.Claims.Perms.Has("violet:compile") { + if !hasPerms(conf.Verify, req, "violet:compile") { utils.RespondHttpStatus(rw, http.StatusForbidden) return } @@ -44,8 +28,36 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server rw.WriteHeader(http.StatusAccepted) }) + // Endpoint for acme-challenge + r.PUT("/acme-challenge/:domain/:key/:value", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if !hasPerms(conf.Verify, req, "violet:acme-challenge") { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + domain := params.ByName("domain") + if !conf.Domains.IsValid(domain) { + utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain") + return + } + conf.Acme.Put(domain, params.ByName("key"), params.ByName("value")) + rw.WriteHeader(http.StatusAccepted) + }) + r.DELETE("/acme-challenge/:domain/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if !hasPerms(conf.Verify, req, "violet:acme-challenge") { + utils.RespondHttpStatus(rw, http.StatusForbidden) + return + } + domain := params.ByName("domain") + if !conf.Domains.IsValid(domain) { + utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid ACME challenge domain") + return + } + conf.Acme.Delete(domain, params.ByName("key")) + rw.WriteHeader(http.StatusAccepted) + }) + // Create and run http server - s := &http.Server{ + return &http.Server{ Addr: conf.ApiListen, Handler: r, ReadTimeout: time.Minute, @@ -54,7 +66,21 @@ func NewApiServer(conf *Conf, compileTarget utils.MultiCompilable) *http.Server IdleTimeout: time.Minute, MaxHeaderBytes: 2500, } - log.Printf("[API] Starting API server on: '%s'\n", s.Addr) - go utils.RunBackgroundHttp("API", s) - return s +} + +func hasPerms(verify mjwt.Provider, req *http.Request, perm string) bool { + // Get bearer token + bearer := utils.GetBearer(req) + if bearer == "" { + return false + } + + // Read claims from mjwt + _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer) + if err != nil { + return false + } + + // Token must have perm + return b.Claims.Perms.Has(perm) } diff --git a/servers/api_test.go b/servers/api_test.go new file mode 100644 index 0000000..f2a2c5c --- /dev/null +++ b/servers/api_test.go @@ -0,0 +1,164 @@ +package servers + +import ( + "code.mrmelon54.com/melon/summer-utils/claims" + "code.mrmelon54.com/melon/summer-utils/claims/auth" + "crypto/rand" + "crypto/rsa" + "github.com/MrMelon54/violet/utils" + "github.com/mrmelon54/mjwt" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +var snakeOilProv = genSnakeOilProv() + +type fakeDomains struct{} + +func (f *fakeDomains) IsValid(host string) bool { return host == "example.com" } + +func genSnakeOilProv() mjwt.Provider { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + return mjwt.NewMJwtSigner("violet.test", key) +} + +func genSnakeOilKey(perm string) string { + p := claims.NewPermStorage() + p.Set(perm) + val, err := snakeOilProv.GenerateJwt("abc", "abc", 5*time.Minute, auth.AccessTokenClaims{ + UserId: 1, + Perms: p, + }) + if err != nil { + panic(err) + } + return val +} + +type fakeCompilable struct{ done bool } + +func (f *fakeCompilable) Compile() { f.done = true } + +var _ utils.Compilable = &fakeCompilable{} + +func TestNewApiServer_Compile(t *testing.T) { + apiConf := &Conf{ + Domains: &fakeDomains{}, + Acme: utils.NewAcmeChallenge(), + Verify: snakeOilProv, + } + f := &fakeCompilable{} + srv := NewApiServer(apiConf, utils.MultiCompilable{f}) + + req, err := http.NewRequest(http.MethodPost, "https://example.com/compile", nil) + assert.NoError(t, err) + + rec := httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res := rec.Result() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + assert.False(t, f.done) + + req.Header.Set("Authorization", "Bearer "+genSnakeOilKey("violet:compile")) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusAccepted, res.StatusCode) + assert.True(t, f.done) +} + +func TestNewApiServer_AcmeChallenge_Put(t *testing.T) { + apiConf := &Conf{ + Domains: &fakeDomains{}, + Acme: utils.NewAcmeChallenge(), + Verify: snakeOilProv, + } + srv := NewApiServer(apiConf, utils.MultiCompilable{}) + acmeKey := genSnakeOilKey("violet:acme-challenge") + + // Valid domain + req, err := http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/example.com/123/123abc", nil) + assert.NoError(t, err) + + rec := httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res := rec.Result() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + req.Header.Set("Authorization", "Bearer "+acmeKey) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusAccepted, res.StatusCode) + assert.Equal(t, "123abc", apiConf.Acme.Get("example.com", "123")) + + // Invalid domain + req, err = http.NewRequest(http.MethodPut, "https://example.com/acme-challenge/notexample.com/123/123abc", nil) + assert.NoError(t, err) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + req.Header.Set("Authorization", "Bearer "+acmeKey) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusBadRequest, res.StatusCode) + assert.Equal(t, "Invalid ACME challenge domain", res.Header.Get("X-Violet-Error")) +} + +func TestNewApiServer_AcmeChallenge_Delete(t *testing.T) { + apiConf := &Conf{ + Domains: &fakeDomains{}, + Acme: utils.NewAcmeChallenge(), + Verify: snakeOilProv, + } + srv := NewApiServer(apiConf, utils.MultiCompilable{}) + acmeKey := genSnakeOilKey("violet:acme-challenge") + + // Valid domain + req, err := http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/example.com/123", nil) + assert.NoError(t, err) + + rec := httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res := rec.Result() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + req.Header.Set("Authorization", "Bearer "+acmeKey) + apiConf.Acme.Put("example.com", "123", "123abc") + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusAccepted, res.StatusCode) + assert.Equal(t, "", apiConf.Acme.Get("example.com", "123")) + + // Invalid domain + req, err = http.NewRequest(http.MethodDelete, "https://example.com/acme-challenge/notexample.com/123", nil) + assert.NoError(t, err) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + req.Header.Set("Authorization", "Bearer "+acmeKey) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusBadRequest, res.StatusCode) + assert.Equal(t, "Invalid ACME challenge domain", res.Header.Get("X-Violet-Error")) +} diff --git a/servers/conf.go b/servers/conf.go index 5bf11a3..c885978 100644 --- a/servers/conf.go +++ b/servers/conf.go @@ -1,9 +1,8 @@ package servers import ( + "crypto/tls" "database/sql" - "github.com/MrMelon54/violet/certs" - "github.com/MrMelon54/violet/domains" errorPages "github.com/MrMelon54/violet/error-pages" "github.com/MrMelon54/violet/favicons" "github.com/MrMelon54/violet/router" @@ -15,11 +14,27 @@ type Conf struct { ApiListen string // api server listen address HttpListen string // http server listen address HttpsListen string // https server listen address + RateLimit uint64 // rate limit per minute DB *sql.DB - Domains *domains.Domains - Certs *certs.Certs + Domains DomainProvider + Acme AcmeChallengeProvider + Certs CertProvider Favicons *favicons.Favicons Verify mjwt.Provider ErrorPages *errorPages.ErrorPages Router *router.Manager } + +type DomainProvider interface { + IsValid(host string) bool +} + +type AcmeChallengeProvider interface { + Get(domain, key string) string + Put(domain, key, value string) + Delete(domain, key string) +} + +type CertProvider interface { + GetCertForDomain(domain string) *tls.Certificate +} diff --git a/servers/http.go b/servers/http.go index 416ff7f..fe1fb58 100644 --- a/servers/http.go +++ b/servers/http.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/MrMelon54/violet/utils" "github.com/julienschmidt/httprouter" - "log" "net/http" "net/url" "time" @@ -27,60 +26,42 @@ func NewHttpServer(conf *Conf) *http.Server { } // Endpoint for acme challenge outputs - r.GET("/.well-known/acme-challenge/{key}", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - if h, ok := utils.GetDomainWithoutPort(req.Host); ok { - // check if the host is valid - if !conf.Domains.IsValid(req.Host) { - http.Error(rw, fmt.Sprintf("%d %s\n", 420, "Invalid host"), 420) - return - } + r.GET("/.well-known/acme-challenge/:key", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + h := utils.GetDomainWithoutPort(req.Host) - // check if the key is valid - key := params.ByName("key") - if key == "" { - rw.WriteHeader(http.StatusNotFound) - return - } - - // prepare for executing query - prepare, err := conf.DB.Prepare("select value from acme_challenges limit 1 where domain = ? and key = ?") - if err != nil { - utils.RespondHttpStatus(rw, http.StatusInternalServerError) - return - } - - // query the row and extract the value - row := prepare.QueryRow(h, key) - var value string - err = row.Scan(&value) - if err != nil { - utils.RespondHttpStatus(rw, http.StatusInternalServerError) - return - } - - // output response - rw.WriteHeader(http.StatusOK) - _, _ = rw.Write([]byte(value)) + // check if the host is valid + if !conf.Domains.IsValid(req.Host) { + utils.RespondVioletError(rw, http.StatusBadRequest, "Invalid host") + return } - rw.WriteHeader(http.StatusNotFound) + + // check if the key is valid + value := conf.Acme.Get(h, params.ByName("key")) + if value == "" { + rw.WriteHeader(http.StatusNotFound) + return + } + + // output response + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte(value)) }) // All other paths lead here and are forwarded to HTTPS r.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if h, ok := utils.GetDomainWithoutPort(req.Host); ok { - u := &url.URL{ - Scheme: "https", - Host: h + secureExtend, - Path: req.URL.Path, - RawPath: req.URL.RawPath, - RawQuery: req.URL.RawQuery, - } - utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect) + h := utils.GetDomainWithoutPort(req.Host) + u := &url.URL{ + Scheme: "https", + Host: h + secureExtend, + Path: req.URL.Path, + RawPath: req.URL.RawPath, + RawQuery: req.URL.RawQuery, } + utils.FastRedirect(rw, req, u.String(), http.StatusPermanentRedirect) }) // Create and run http server - s := &http.Server{ + return &http.Server{ Addr: conf.HttpListen, Handler: r, ReadTimeout: time.Minute, @@ -89,7 +70,4 @@ func NewHttpServer(conf *Conf) *http.Server { IdleTimeout: time.Minute, MaxHeaderBytes: 2500, } - log.Printf("[HTTP] Starting HTTP server on: '%s'\n", s.Addr) - go utils.RunBackgroundHttp("HTTP", s) - return s } diff --git a/servers/http_test.go b/servers/http_test.go new file mode 100644 index 0000000..4069e6e --- /dev/null +++ b/servers/http_test.go @@ -0,0 +1,46 @@ +package servers + +import ( + "bytes" + "github.com/MrMelon54/violet/utils" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewHttpServer_AcmeChallenge(t *testing.T) { + httpConf := &Conf{ + Domains: &fakeDomains{}, + Acme: utils.NewAcmeChallenge(), + Verify: snakeOilProv, + } + srv := NewHttpServer(httpConf) + httpConf.Acme.Put("example.com", "456", "456def") + + req, err := http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/456", nil) + assert.NoError(t, err) + + rec := httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res := rec.Result() + assert.Equal(t, http.StatusOK, res.StatusCode) + + all, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare([]byte("456def"), all)) + + // Invalid key + req, err = http.NewRequest(http.MethodGet, "https://example.com/.well-known/acme-challenge/789", nil) + assert.NoError(t, err) + + rec = httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res = rec.Result() + assert.Equal(t, http.StatusNotFound, res.StatusCode) + + all, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare([]byte(""), all)) +} diff --git a/servers/https.go b/servers/https.go index 185a661..b85a1e7 100644 --- a/servers/https.go +++ b/servers/https.go @@ -16,10 +16,9 @@ import ( // NewHttpsServer creates and runs a http server containing the public https // endpoints for the reverse proxy. func NewHttpsServer(conf *Conf) *http.Server { - s := &http.Server{ - Addr: conf.HttpsListen, - Handler: setupRateLimiter(300, setupFaviconMiddleware(conf.Favicons, conf.Router)), - DisableGeneralOptionsHandler: false, + return &http.Server{ + Addr: conf.HttpsListen, + Handler: setupRateLimiter(conf.RateLimit, setupFaviconMiddleware(conf.Favicons, conf.Router)), TLSConfig: &tls.Config{GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { // error out on invalid domains if !conf.Domains.IsValid(info.ServerName) { @@ -41,12 +40,9 @@ func NewHttpsServer(conf *Conf) *http.Server { IdleTimeout: 150 * time.Second, MaxHeaderBytes: 4096000, ConnState: func(conn net.Conn, state http.ConnState) { - fmt.Printf("%s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String()) + fmt.Printf("[HTTPS] %s => %s: %s\n", conn.LocalAddr(), conn.RemoteAddr(), state.String()) }, } - log.Printf("[HTTPS] Starting HTTPS server on: '%s'\n", s.Addr) - go utils.RunBackgroundHttps("HTTPS", s) - return s } // setupRateLimiter is an internal function to create a middleware to manage diff --git a/servers/https_test.go b/servers/https_test.go new file mode 100644 index 0000000..4c0ee04 --- /dev/null +++ b/servers/https_test.go @@ -0,0 +1,59 @@ +package servers + +import ( + "database/sql" + "github.com/MrMelon54/violet/certs" + "github.com/MrMelon54/violet/proxy" + "github.com/MrMelon54/violet/router" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +type fakeTransport struct{} + +func (f *fakeTransport) RoundTrip(_ *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + return rec.Result(), nil +} + +func TestNewHttpsServer_RateLimit(t *testing.T) { + db, err := sql.Open("sqlite3", "file::memory:?cache=shared") + assert.NoError(t, err) + + ft := &fakeTransport{} + httpsConf := &Conf{ + RateLimit: 5, + Domains: &fakeDomains{}, + Certs: certs.New(nil, nil, true), + Verify: snakeOilProv, + Router: router.NewManager(db, proxy.NewHybridTransportWithCalls(ft, ft)), + } + srv := NewHttpsServer(httpsConf) + + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + req.RemoteAddr = "127.0.0.1:1447" + assert.NoError(t, err) + + wg := &sync.WaitGroup{} + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res := rec.Result() + assert.Equal(t, http.StatusOK, res.StatusCode) + }() + } + wg.Wait() + + rec := httptest.NewRecorder() + srv.Handler.ServeHTTP(rec, req) + res := rec.Result() + assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) +} diff --git a/utils/acme-challenges.go b/utils/acme-challenges.go new file mode 100644 index 0000000..5f206c9 --- /dev/null +++ b/utils/acme-challenges.go @@ -0,0 +1,55 @@ +package utils + +import "sync" + +type AcmeChallenges struct { + s *sync.RWMutex + d map[string]*AcmeStorage +} + +type AcmeStorage struct { + s *sync.RWMutex + v map[string]string +} + +func NewAcmeChallenge() *AcmeChallenges { + return &AcmeChallenges{ + s: &sync.RWMutex{}, + d: make(map[string]*AcmeStorage), + } +} + +func (a *AcmeChallenges) Get(domain, key string) string { + a.s.RLock() + defer a.s.RUnlock() + if m := a.d[domain]; m != nil { + m.s.RLock() + defer m.s.RUnlock() + return m.v[key] + } + return "" +} + +func (a *AcmeChallenges) Put(domain, key, value string) { + a.s.Lock() + m := a.d[domain] + if m == nil { + m = &AcmeStorage{ + s: &sync.RWMutex{}, + v: make(map[string]string), + } + a.d[domain] = m + } + m.s.Lock() + m.v[key] = value + m.s.Unlock() + a.s.Unlock() +} + +func (a *AcmeChallenges) Delete(domain, key string) { + a.s.Lock() + if m := a.d[domain]; m != nil { + delete(m.v, key) + } + a.s.Unlock() +} diff --git a/utils/acme-challenges_test.go b/utils/acme-challenges_test.go new file mode 100644 index 0000000..022273e --- /dev/null +++ b/utils/acme-challenges_test.go @@ -0,0 +1,27 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAcmeChallenges(t *testing.T) { + a := NewAcmeChallenge() + assert.Equal(t, "", a.Get("example.com", "123")) + + // The challenge should be created + a.Put("example.com", "123", "123abc") + assert.Equal(t, "123abc", a.Get("example.com", "123")) + + // The challenge should be deleted + a.Delete("example.com", "123") + assert.Equal(t, "", a.Get("example.com", "123")) + + // This should not crash or stop execution + a.Delete("example.com", "123") + assert.Equal(t, "", a.Get("example.com", "123")) + + // This should not crash or stop execution + a.Delete("www.example.com", "123") + assert.Equal(t, "", a.Get("example.com", "123")) +} diff --git a/utils/domain-utils.go b/utils/domain-utils.go index fee486f..acedb6f 100644 --- a/utils/domain-utils.go +++ b/utils/domain-utils.go @@ -27,13 +27,13 @@ func SplitDomainPort(host string, defaultPort int) (domain string, port int, ok // without the port. // // example.com:443 => example.com -func GetDomainWithoutPort(domain string) (string, bool) { +func GetDomainWithoutPort(domain string) string { // if a valid index isn't found then return false n := strings.LastIndexByte(domain, ':') if n == -1 { - return "", false + return domain } - return domain[:n], true + return domain[:n] } // ReplaceSubdomainWithWildcard returns the domain with the subdomain replaced diff --git a/utils/domain-utils_test.go b/utils/domain-utils_test.go index 46251e1..63f51e6 100644 --- a/utils/domain-utils_test.go +++ b/utils/domain-utils_test.go @@ -18,12 +18,16 @@ func TestSplitDomainPort(t *testing.T) { } func TestDomainWithoutPort(t *testing.T) { - domain, ok := GetDomainWithoutPort("www.example.com:5612") - assert.True(t, ok, "Output should be true") + domain := GetDomainWithoutPort("www.example.com:5612") assert.Equal(t, "www.example.com", domain) - domain, ok = GetDomainWithoutPort("example.com:443") - assert.True(t, ok, "Output should be true") + domain = GetDomainWithoutPort("example.com:443") + assert.Equal(t, "example.com", domain) + + domain = GetDomainWithoutPort("www.example.com") + assert.Equal(t, "www.example.com", domain) + + domain = GetDomainWithoutPort("example.com") assert.Equal(t, "example.com", domain) } diff --git a/utils/multi-compilable_test.go b/utils/multi-compilable_test.go new file mode 100644 index 0000000..f771c91 --- /dev/null +++ b/utils/multi-compilable_test.go @@ -0,0 +1,22 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +type fakeCompile struct{ done bool } + +func (f *fakeCompile) Compile() { + f.done = true +} + +var _ Compilable = &fakeCompile{} + +func TestMultiCompilable_Compile(t *testing.T) { + f := &fakeCompile{} + a := MultiCompilable{f} + assert.False(t, f.done) + a.Compile() + assert.True(t, f.done) +} diff --git a/utils/response_test.go b/utils/response_test.go new file mode 100644 index 0000000..2f33781 --- /dev/null +++ b/utils/response_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRespondHttpStatus(t *testing.T) { + rec := httptest.NewRecorder() + RespondHttpStatus(rec, http.StatusTeapot) + res := rec.Result() + assert.Equal(t, http.StatusTeapot, res.StatusCode) + assert.Equal(t, "418 I'm a teapot", res.Status) + a, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "418 I'm a teapot\n\n", string(a)) +} + +func TestRespondVioletError(t *testing.T) { + rec := httptest.NewRecorder() + RespondVioletError(rec, http.StatusTeapot, "Hidden Error Message") + res := rec.Result() + assert.Equal(t, http.StatusTeapot, res.StatusCode) + assert.Equal(t, "418 I'm a teapot", res.Status) + a, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "418 I'm a teapot\n\n", string(a)) + assert.Equal(t, "Hidden Error Message", res.Header.Get("X-Violet-Error")) +} diff --git a/utils/server-utils_test.go b/utils/server-utils_test.go new file mode 100644 index 0000000..d901986 --- /dev/null +++ b/utils/server-utils_test.go @@ -0,0 +1,20 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestGetBearer(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://example.com", nil) + assert.NoError(t, err) + req.Header.Set("Authorization", "Bearer abc") + assert.Equal(t, "abc", GetBearer(req)) +} + +func TestGetBearer_Empty(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://example.com", nil) + assert.NoError(t, err) + assert.Equal(t, "", GetBearer(req)) +}