diff --git a/go.mod b/go.mod index bdd51ae..0d5230f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/cloudflare/tableflip v1.2.3 github.com/dustin/go-humanize v1.0.1 github.com/golang-migrate/migrate/v4 v4.18.1 + github.com/google/uuid v1.6.0 github.com/julienschmidt/httprouter v1.3.0 github.com/spf13/afero v1.11.0 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index e31a4d7..7330e28 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4C github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/upload/upload.go b/upload/upload.go index c5282b6..a40b75d 100644 --- a/upload/upload.go +++ b/upload/upload.go @@ -20,6 +20,7 @@ import ( "slices" "strings" "sync" + "time" ) var indexBranches = []string{ @@ -29,6 +30,8 @@ var indexBranches = []string{ type uploadQueries interface { GetSiteByDomain(ctx context.Context, domain string) (database.Site, error) + AddBranch(ctx context.Context, arg database.AddBranchParams) error + UpdateBranch(ctx context.Context, arg database.UpdateBranchParams) error } func New(storage afero.Fs, db uploadQueries) *Handler { @@ -165,5 +168,21 @@ func (h *Handler) extractTarGzUpload(fileData io.Reader, site, branch string) er return fmt.Errorf("failed to copy from archive to output file: '%s': %w", next.Name, err) } } + + n := time.Now().UTC() + + err = h.db.AddBranch(context.Background(), database.AddBranchParams{ + Branch: branch, + Domain: site, + LastUpdate: n, + Enable: true, + }) + if err != nil { + return h.db.UpdateBranch(context.Background(), database.UpdateBranchParams{ + Branch: branch, + Domain: site, + LastUpdate: n, + }) + } return nil } diff --git a/upload/upload_test.go b/upload/upload_test.go index 1bdbdfe..abaeb42 100644 --- a/upload/upload_test.go +++ b/upload/upload_test.go @@ -6,7 +6,9 @@ import ( "database/sql" _ "embed" "fmt" + "github.com/1f349/bluebell" "github.com/1f349/bluebell/database" + "github.com/google/uuid" "github.com/julienschmidt/httprouter" "github.com/spf13/afero" "github.com/stretchr/testify/assert" @@ -14,6 +16,8 @@ import ( "mime/multipart" "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" ) @@ -22,6 +26,21 @@ var ( testArchiveTarGz []byte ) +func initMemoryDB(t *testing.T) (*sql.DB, *database.Queries) { + rawDB, err := sql.Open("sqlite3", "file:"+uuid.NewString()+"?mode=memory&cache=shared") + assert.NoError(t, err) + + db, err := bluebell.InitRawDB(rawDB) + assert.NoError(t, err) + + assert.NoError(t, db.AddSite(context.Background(), database.AddSiteParams{ + Domain: "example.com", + Token: "abcd1234", + })) + + return rawDB, db +} + func assertUploadedFile(t *testing.T, fs afero.Fs, branch string) { switch branch { case "main", "master": @@ -43,6 +62,8 @@ func assertUploadedFile(t *testing.T, fs afero.Fs, branch string) { } type fakeUploadDB struct { + branchesMu sync.Mutex + branchesMap map[string]database.Branch } func (f *fakeUploadDB) GetSiteByDomain(_ context.Context, domain string) (database.Site, error) { @@ -55,6 +76,42 @@ func (f *fakeUploadDB) GetSiteByDomain(_ context.Context, domain string) (databa return database.Site{}, sql.ErrNoRows } +func (f *fakeUploadDB) AddBranch(ctx context.Context, arg database.AddBranchParams) error { + f.branchesMu.Lock() + defer f.branchesMu.Unlock() + if f.branchesMap == nil { + f.branchesMap = make(map[string]database.Branch) + } + + key := arg.Domain + "@" + arg.Branch + _, exists := f.branchesMap[key] + if exists { + return fmt.Errorf("primary key constraint failed") + } + f.branchesMap[key] = database.Branch{ + Domain: arg.Domain, + Branch: arg.Branch, + LastUpdate: arg.LastUpdate, + Enable: arg.Enable, + } + return nil +} + +func (f *fakeUploadDB) UpdateBranch(ctx context.Context, arg database.UpdateBranchParams) error { + f.branchesMu.Lock() + defer f.branchesMu.Unlock() + if f.branchesMap == nil { + f.branchesMap = make(map[string]database.Branch) + } + + item, exists := f.branchesMap[arg.Domain+"@"+arg.Branch] + if !exists { + return sql.ErrNoRows + } + item.LastUpdate = arg.LastUpdate + return nil +} + func TestHandler_Handle(t *testing.T) { fs := afero.NewMemMapFs() h := New(fs, new(fakeUploadDB)) @@ -95,11 +152,20 @@ func TestHandler_Handle(t *testing.T) { } } -func TestHandler_extractTarGzUpload(t *testing.T) { +func TestHandler_extractTarGzUpload_fakeDB(t *testing.T) { + extractTarGzUploadTest(t, new(fakeUploadDB)) +} + +func TestHandler_extractTarGzUpload_memoryDB(t *testing.T) { + _, db := initMemoryDB(t) + extractTarGzUploadTest(t, db) +} + +func extractTarGzUploadTest(t *testing.T, db uploadQueries) { for _, branch := range []string{"main", "test", "dev"} { t.Run(branch+" branch", func(t *testing.T) { fs := afero.NewMemMapFs() - h := New(fs, new(fakeUploadDB)) + h := New(fs, db) buffer := bytes.NewBuffer(testArchiveTarGz) assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", branch)) @@ -107,3 +173,64 @@ func TestHandler_extractTarGzUpload(t *testing.T) { }) } } + +func TestHandler_extractTarGzUpload_fakeDB_multiple(t *testing.T) { + extractTarGzUploadMultipleTest(t, new(fakeUploadDB)) +} + +func TestHandler_extractTarGzUpload_memoryDB_multiple(t *testing.T) { + _, db := initMemoryDB(t) + extractTarGzUploadMultipleTest(t, db) +} + +func extractTarGzUploadMultipleTest(t *testing.T, db uploadQueries) { + fs := afero.NewMemMapFs() + h := New(fs, db) + sig := new(atomic.Bool) + wg := new(sync.WaitGroup) + + const callCount = 2 + wg.Add(callCount) + + for range callCount { + go func() { + defer wg.Done() + buffer := newSingleBufferReader(testArchiveTarGz, sig) + assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", "main")) + assertUploadedFile(t, fs, "main") + }() + } + + wg.Wait() + + assertUploadedFile(t, fs, "main") +} + +type singleBufferReader struct { + r io.Reader + signal *atomic.Bool + active bool + size int +} + +func (s *singleBufferReader) Read(b []byte) (n int, err error) { + // only check the signal if this reader is not active + if !s.active { + old := s.signal.Swap(true) + if old { + panic("singleBufferReader: peer buffer is already being used") + } + s.active = true + } + n, err = s.r.Read(b) + s.size -= n + if s.size <= 0 { + s.active = false + s.signal.Store(false) + } + return n, err +} + +func newSingleBufferReader(buf []byte, b *atomic.Bool) *singleBufferReader { + return &singleBufferReader{r: bytes.NewReader(buf), signal: b, size: len(buf)} +}