From e0fb935aaf3d06d64e176b56499f9222d0048933 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Tue, 7 Jan 2025 23:51:31 +0000 Subject: [PATCH] Test site and branch names --- upload/upload.go | 67 ++++++++++++++++++++++++++++++++++++++++++- upload/upload_test.go | 62 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/upload/upload.go b/upload/upload.go index 703aa21..7fee8b7 100644 --- a/upload/upload.go +++ b/upload/upload.go @@ -24,6 +24,59 @@ var indexBranches = []string{ "master", } +func containsOnly(s string, f func(r rune) bool) bool { + for _, r := range []rune(s) { + if !f(r) { + return false + } + } + return true +} + +func isValidSite(site string) bool { + if len(site) < 1 || site[0] == '-' { + return false + } + switch site[0] { + case '-': + return false + } + return containsOnly(site, func(r rune) bool { + return isAlphanumericOrDash(r) || r == '.' + }) +} + +func isValidBranch(branch string) bool { + if len(branch) < 1 { + return false + } + switch branch[0] { + case '-', '/': + return false + } + if branch[len(branch)-1] == '/' { + return false + } + return containsOnly(branch, func(r rune) bool { + return isAlphanumericOrDash(r) || r == '/' || r == '.' + }) +} + +func isAlphanumericOrDash(r rune) bool { + switch { + case r >= '0' && r <= '9': + return true + case r >= 'a' && r <= 'z': + return true + case r >= 'A' && r <= 'Z': + return true + case r == '-', r == '_': + return true + default: + return false + } +} + type sitesQueries interface { GetSiteByDomain(ctx context.Context, domain string) (database.Site, error) } @@ -76,14 +129,26 @@ func (h *Handler) Handle(rw http.ResponseWriter, req *http.Request, params httpr } func (h *Handler) extractTarGzUpload(fileData io.Reader, site, branch string) error { + if !isValidSite(site) { + return fmt.Errorf("invalid site name: %s", site) + } + if !isValidBranch(branch) { + return fmt.Errorf("invalid branch name: %s", branch) + } if slices.Contains(indexBranches, branch) { branch = "" } + + _, err := h.db.GetSiteByDomain(context.Background(), site) + if err != nil { + return fmt.Errorf("invalid site: %w", err) + } + siteBranchPath := filepath.Join(site, "@"+branch) siteBranchOldPath := filepath.Join(site, "old@"+branch) // try the new "old@[...]" and old "@[...].old" paths - err := h.storageFs.RemoveAll(siteBranchPath + ".old") + err = h.storageFs.RemoveAll(siteBranchPath + ".old") if err != nil && !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("failed to remove old site branch %s: %w", siteBranchPath, err) } diff --git a/upload/upload_test.go b/upload/upload_test.go index a874079..4e05053 100644 --- a/upload/upload_test.go +++ b/upload/upload_test.go @@ -17,6 +17,66 @@ import ( "testing" ) +func TestIsValidSite(t *testing.T) { + for _, i := range []struct { + s string + valid bool + }{ + {"", false}, + {"a", true}, + {"abc", true}, + {"0", true}, + {"0123456789", true}, + {"_", true}, + {"_ab", true}, + {"-", false}, + {"-ab", false}, + {".", true}, + {".ab", true}, + {"a-b", true}, + {"a_b", true}, + {"a/b", false}, + {"a.b", true}, + {"ab-", true}, + {"ab_", true}, + {"ab.", true}, + {"/ab", false}, + {"ab/", false}, + } { + assert.Equal(t, i.valid, isValidSite(i.s), "Test failed \"%s\" - %v", i.s, i.valid) + } +} + +func TestIsValidBranch(t *testing.T) { + for _, i := range []struct { + s string + valid bool + }{ + {"", false}, + {"a", true}, + {"abc", true}, + {"0", true}, + {"0123456789", true}, + {"_", true}, + {"_ab", true}, + {"-", false}, + {"-ab", false}, + {".", true}, + {".ab", true}, + {"a-b", true}, + {"a_b", true}, + {"a/b", true}, + {"a.b", true}, + {"ab-", true}, + {"ab_", true}, + {"ab.", true}, + {"/ab", false}, + {"ab/", false}, + } { + assert.Equal(t, i.valid, isValidBranch(i.s), "Test failed \"%s\" - %v", i.s, i.valid) + } +} + var ( //go:embed test-archive.tar.gz testArchiveTarGz []byte @@ -100,7 +160,7 @@ func TestHandler_extractTarGzUpload(t *testing.T) { for _, branch := range []string{"main", "test", "dev"} { t.Run(branch+" branch", func(t *testing.T) { fs := afero.NewMemMapFs() - h := New(fs, nil) + h := New(fs, new(fakeUploadDB)) buffer := bytes.NewBuffer(testArchiveTarGz) assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", branch))