Test site and branch names

This commit is contained in:
Melon 2025-01-07 23:51:31 +00:00
parent 731a9953f1
commit e0fb935aaf
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
2 changed files with 127 additions and 2 deletions

View File

@ -24,6 +24,59 @@ var indexBranches = []string{
"master", "master",
} }
func containsOnly(s string, f func(r rune) bool) bool {
for _, r := range []rune(s) {
if !f(r) {
return false
}
}
return true
}
func isValidSite(site string) bool {
if len(site) < 1 || site[0] == '-' {
return false
}
switch site[0] {
case '-':
return false
}
return containsOnly(site, func(r rune) bool {
return isAlphanumericOrDash(r) || r == '.'
})
}
func isValidBranch(branch string) bool {
if len(branch) < 1 {
return false
}
switch branch[0] {
case '-', '/':
return false
}
if branch[len(branch)-1] == '/' {
return false
}
return containsOnly(branch, func(r rune) bool {
return isAlphanumericOrDash(r) || r == '/' || r == '.'
})
}
func isAlphanumericOrDash(r rune) bool {
switch {
case r >= '0' && r <= '9':
return true
case r >= 'a' && r <= 'z':
return true
case r >= 'A' && r <= 'Z':
return true
case r == '-', r == '_':
return true
default:
return false
}
}
type sitesQueries interface { type sitesQueries interface {
GetSiteByDomain(ctx context.Context, domain string) (database.Site, error) GetSiteByDomain(ctx context.Context, domain string) (database.Site, error)
} }
@ -76,14 +129,26 @@ func (h *Handler) Handle(rw http.ResponseWriter, req *http.Request, params httpr
} }
func (h *Handler) extractTarGzUpload(fileData io.Reader, site, branch string) error { func (h *Handler) extractTarGzUpload(fileData io.Reader, site, branch string) error {
if !isValidSite(site) {
return fmt.Errorf("invalid site name: %s", site)
}
if !isValidBranch(branch) {
return fmt.Errorf("invalid branch name: %s", branch)
}
if slices.Contains(indexBranches, branch) { if slices.Contains(indexBranches, branch) {
branch = "" branch = ""
} }
_, err := h.db.GetSiteByDomain(context.Background(), site)
if err != nil {
return fmt.Errorf("invalid site: %w", err)
}
siteBranchPath := filepath.Join(site, "@"+branch) siteBranchPath := filepath.Join(site, "@"+branch)
siteBranchOldPath := filepath.Join(site, "old@"+branch) siteBranchOldPath := filepath.Join(site, "old@"+branch)
// try the new "old@[...]" and old "@[...].old" paths // try the new "old@[...]" and old "@[...].old" paths
err := h.storageFs.RemoveAll(siteBranchPath + ".old") err = h.storageFs.RemoveAll(siteBranchPath + ".old")
if err != nil && !errors.Is(err, fs.ErrNotExist) { if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to remove old site branch %s: %w", siteBranchPath, err) return fmt.Errorf("failed to remove old site branch %s: %w", siteBranchPath, err)
} }

View File

@ -17,6 +17,66 @@ import (
"testing" "testing"
) )
func TestIsValidSite(t *testing.T) {
for _, i := range []struct {
s string
valid bool
}{
{"", false},
{"a", true},
{"abc", true},
{"0", true},
{"0123456789", true},
{"_", true},
{"_ab", true},
{"-", false},
{"-ab", false},
{".", true},
{".ab", true},
{"a-b", true},
{"a_b", true},
{"a/b", false},
{"a.b", true},
{"ab-", true},
{"ab_", true},
{"ab.", true},
{"/ab", false},
{"ab/", false},
} {
assert.Equal(t, i.valid, isValidSite(i.s), "Test failed \"%s\" - %v", i.s, i.valid)
}
}
func TestIsValidBranch(t *testing.T) {
for _, i := range []struct {
s string
valid bool
}{
{"", false},
{"a", true},
{"abc", true},
{"0", true},
{"0123456789", true},
{"_", true},
{"_ab", true},
{"-", false},
{"-ab", false},
{".", true},
{".ab", true},
{"a-b", true},
{"a_b", true},
{"a/b", true},
{"a.b", true},
{"ab-", true},
{"ab_", true},
{"ab.", true},
{"/ab", false},
{"ab/", false},
} {
assert.Equal(t, i.valid, isValidBranch(i.s), "Test failed \"%s\" - %v", i.s, i.valid)
}
}
var ( var (
//go:embed test-archive.tar.gz //go:embed test-archive.tar.gz
testArchiveTarGz []byte testArchiveTarGz []byte
@ -100,7 +160,7 @@ func TestHandler_extractTarGzUpload(t *testing.T) {
for _, branch := range []string{"main", "test", "dev"} { for _, branch := range []string{"main", "test", "dev"} {
t.Run(branch+" branch", func(t *testing.T) { t.Run(branch+" branch", func(t *testing.T) {
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
h := New(fs, nil) h := New(fs, new(fakeUploadDB))
buffer := bytes.NewBuffer(testArchiveTarGz) buffer := bytes.NewBuffer(testArchiveTarGz)
assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", branch)) assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", branch))