From 2193c44252826aadac6c4cb2cc83fde25d70ca30 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Wed, 8 Jan 2025 00:59:27 +0000 Subject: [PATCH] Move validation to a separate package --- api/api.go | 11 ++++++ upload/upload.go | 58 ++---------------------------- upload/upload_test.go | 60 ------------------------------- validation/validation.go | 54 ++++++++++++++++++++++++++++ validation/validation_test.go | 66 +++++++++++++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 115 deletions(-) create mode 100644 validation/validation.go create mode 100644 validation/validation_test.go diff --git a/api/api.go b/api/api.go index f5d5e71..1cc07d6 100644 --- a/api/api.go +++ b/api/api.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/1f349/bluebell/database" "github.com/1f349/bluebell/upload" + "github.com/1f349/bluebell/validation" "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" "github.com/julienschmidt/httprouter" @@ -32,6 +33,16 @@ func setEnabled(rw http.ResponseWriter, req *http.Request, params httprouter.Par host := params.ByName("host") branch := params.ByName("branch") + if !validation.IsValidSite(host) { + http.Error(rw, "Invalid site", http.StatusBadRequest) + return + } + + if !validation.IsValidBranch(branch) { + http.Error(rw, "Invalid branch", http.StatusBadRequest) + return + } + if !validateDomainOwnershipClaims(host, b.Claims.Perms) { http.Error(rw, "Forbidden", http.StatusForbidden) return diff --git a/upload/upload.go b/upload/upload.go index 3ecc55f..b456e8a 100644 --- a/upload/upload.go +++ b/upload/upload.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "github.com/1f349/bluebell/database" + "github.com/1f349/bluebell/validation" "github.com/1f349/syncmap" "github.com/dustin/go-humanize" "github.com/julienschmidt/httprouter" @@ -26,59 +27,6 @@ var indexBranches = []string{ "master", } -func containsOnly(s string, f func(r rune) bool) bool { - for _, r := range []rune(s) { - if !f(r) { - return false - } - } - return true -} - -func isValidSite(site string) bool { - if len(site) < 1 || site[0] == '-' { - return false - } - switch site[0] { - case '-': - return false - } - return containsOnly(site, func(r rune) bool { - return isAlphanumericOrDash(r) || r == '.' - }) -} - -func isValidBranch(branch string) bool { - if len(branch) < 1 { - return false - } - switch branch[0] { - case '-', '/': - return false - } - if branch[len(branch)-1] == '/' { - return false - } - return containsOnly(branch, func(r rune) bool { - return isAlphanumericOrDash(r) || r == '/' || r == '.' - }) -} - -func isAlphanumericOrDash(r rune) bool { - switch { - case r >= '0' && r <= '9': - return true - case r >= 'a' && r <= 'z': - return true - case r >= 'A' && r <= 'Z': - return true - case r == '-', r == '_': - return true - default: - return false - } -} - type sitesQueries interface { GetSiteByDomain(ctx context.Context, domain string) (database.Site, error) } @@ -132,10 +80,10 @@ func (h *Handler) Handle(rw http.ResponseWriter, req *http.Request, params httpr } func (h *Handler) extractTarGzUpload(fileData io.Reader, site, branch string) error { - if !isValidSite(site) { + if !validation.IsValidSite(site) { return fmt.Errorf("invalid site name: %s", site) } - if !isValidBranch(branch) { + if !validation.IsValidBranch(branch) { return fmt.Errorf("invalid branch name: %s", branch) } if slices.Contains(indexBranches, branch) { diff --git a/upload/upload_test.go b/upload/upload_test.go index 4e05053..1cf5030 100644 --- a/upload/upload_test.go +++ b/upload/upload_test.go @@ -17,66 +17,6 @@ import ( "testing" ) -func TestIsValidSite(t *testing.T) { - for _, i := range []struct { - s string - valid bool - }{ - {"", false}, - {"a", true}, - {"abc", true}, - {"0", true}, - {"0123456789", true}, - {"_", true}, - {"_ab", true}, - {"-", false}, - {"-ab", false}, - {".", true}, - {".ab", true}, - {"a-b", true}, - {"a_b", true}, - {"a/b", false}, - {"a.b", true}, - {"ab-", true}, - {"ab_", true}, - {"ab.", true}, - {"/ab", false}, - {"ab/", false}, - } { - assert.Equal(t, i.valid, isValidSite(i.s), "Test failed \"%s\" - %v", i.s, i.valid) - } -} - -func TestIsValidBranch(t *testing.T) { - for _, i := range []struct { - s string - valid bool - }{ - {"", false}, - {"a", true}, - {"abc", true}, - {"0", true}, - {"0123456789", true}, - {"_", true}, - {"_ab", true}, - {"-", false}, - {"-ab", false}, - {".", true}, - {".ab", true}, - {"a-b", true}, - {"a_b", true}, - {"a/b", true}, - {"a.b", true}, - {"ab-", true}, - {"ab_", true}, - {"ab.", true}, - {"/ab", false}, - {"ab/", false}, - } { - assert.Equal(t, i.valid, isValidBranch(i.s), "Test failed \"%s\" - %v", i.s, i.valid) - } -} - var ( //go:embed test-archive.tar.gz testArchiveTarGz []byte diff --git a/validation/validation.go b/validation/validation.go new file mode 100644 index 0000000..aa68ee5 --- /dev/null +++ b/validation/validation.go @@ -0,0 +1,54 @@ +package validation + +func IsValidSite(site string) bool { + if len(site) < 1 || site[0] == '-' { + return false + } + switch site[0] { + case '-': + return false + } + return containsOnly(site, func(r rune) bool { + return isAlphanumericOrDash(r) || r == '.' + }) +} + +func IsValidBranch(branch string) bool { + if len(branch) < 1 { + return false + } + switch branch[0] { + case '-', '/': + return false + } + if branch[len(branch)-1] == '/' { + return false + } + return containsOnly(branch, func(r rune) bool { + return isAlphanumericOrDash(r) || r == '/' || r == '.' + }) +} + +func isAlphanumericOrDash(r rune) bool { + switch { + case r >= '0' && r <= '9': + return true + case r >= 'a' && r <= 'z': + return true + case r >= 'A' && r <= 'Z': + return true + case r == '-', r == '_': + return true + default: + return false + } +} + +func containsOnly(s string, f func(r rune) bool) bool { + for _, r := range []rune(s) { + if !f(r) { + return false + } + } + return true +} diff --git a/validation/validation_test.go b/validation/validation_test.go new file mode 100644 index 0000000..382f2ff --- /dev/null +++ b/validation/validation_test.go @@ -0,0 +1,66 @@ +package validation + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestIsValidSite(t *testing.T) { + for _, i := range []struct { + s string + valid bool + }{ + {"", false}, + {"a", true}, + {"abc", true}, + {"0", true}, + {"0123456789", true}, + {"_", true}, + {"_ab", true}, + {"-", false}, + {"-ab", false}, + {".", true}, + {".ab", true}, + {"a-b", true}, + {"a_b", true}, + {"a/b", false}, + {"a.b", true}, + {"ab-", true}, + {"ab_", true}, + {"ab.", true}, + {"/ab", false}, + {"ab/", false}, + } { + assert.Equal(t, i.valid, IsValidSite(i.s), "Test failed \"%s\" - %v", i.s, i.valid) + } +} + +func TestIsValidBranch(t *testing.T) { + for _, i := range []struct { + s string + valid bool + }{ + {"", false}, + {"a", true}, + {"abc", true}, + {"0", true}, + {"0123456789", true}, + {"_", true}, + {"_ab", true}, + {"-", false}, + {"-ab", false}, + {".", true}, + {".ab", true}, + {"a-b", true}, + {"a_b", true}, + {"a/b", true}, + {"a.b", true}, + {"ab-", true}, + {"ab_", true}, + {"ab.", true}, + {"/ab", false}, + {"ab/", false}, + } { + assert.Equal(t, i.valid, IsValidBranch(i.s), "Test failed \"%s\" - %v", i.s, i.valid) + } +}