Update branches table after updating the site

This also adds tests to ensure multiple calls to the upload endpoint lock correctly
This commit is contained in:
Melon 2025-01-08 17:58:56 +00:00
parent 4ecd88c0c0
commit 5f50ec5a56
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
4 changed files with 151 additions and 2 deletions

1
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/cloudflare/tableflip v1.2.3 github.com/cloudflare/tableflip v1.2.3
github.com/dustin/go-humanize v1.0.1 github.com/dustin/go-humanize v1.0.1
github.com/golang-migrate/migrate/v4 v4.18.1 github.com/golang-migrate/migrate/v4 v4.18.1
github.com/google/uuid v1.6.0
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
github.com/spf13/afero v1.11.0 github.com/spf13/afero v1.11.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0

2
go.sum
View File

@ -30,6 +30,8 @@ github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4C
github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=

View File

@ -20,6 +20,7 @@ import (
"slices" "slices"
"strings" "strings"
"sync" "sync"
"time"
) )
var indexBranches = []string{ var indexBranches = []string{
@ -29,6 +30,8 @@ var indexBranches = []string{
type uploadQueries interface { type uploadQueries interface {
GetSiteByDomain(ctx context.Context, domain string) (database.Site, error) GetSiteByDomain(ctx context.Context, domain string) (database.Site, error)
AddBranch(ctx context.Context, arg database.AddBranchParams) error
UpdateBranch(ctx context.Context, arg database.UpdateBranchParams) error
} }
func New(storage afero.Fs, db uploadQueries) *Handler { func New(storage afero.Fs, db uploadQueries) *Handler {
@ -165,5 +168,21 @@ func (h *Handler) extractTarGzUpload(fileData io.Reader, site, branch string) er
return fmt.Errorf("failed to copy from archive to output file: '%s': %w", next.Name, err) return fmt.Errorf("failed to copy from archive to output file: '%s': %w", next.Name, err)
} }
} }
n := time.Now().UTC()
err = h.db.AddBranch(context.Background(), database.AddBranchParams{
Branch: branch,
Domain: site,
LastUpdate: n,
Enable: true,
})
if err != nil {
return h.db.UpdateBranch(context.Background(), database.UpdateBranchParams{
Branch: branch,
Domain: site,
LastUpdate: n,
})
}
return nil return nil
} }

View File

@ -6,7 +6,9 @@ import (
"database/sql" "database/sql"
_ "embed" _ "embed"
"fmt" "fmt"
"github.com/1f349/bluebell"
"github.com/1f349/bluebell/database" "github.com/1f349/bluebell/database"
"github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -14,6 +16,8 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"sync/atomic"
"testing" "testing"
) )
@ -22,6 +26,21 @@ var (
testArchiveTarGz []byte testArchiveTarGz []byte
) )
func initMemoryDB(t *testing.T) (*sql.DB, *database.Queries) {
rawDB, err := sql.Open("sqlite3", "file:"+uuid.NewString()+"?mode=memory&cache=shared")
assert.NoError(t, err)
db, err := bluebell.InitRawDB(rawDB)
assert.NoError(t, err)
assert.NoError(t, db.AddSite(context.Background(), database.AddSiteParams{
Domain: "example.com",
Token: "abcd1234",
}))
return rawDB, db
}
func assertUploadedFile(t *testing.T, fs afero.Fs, branch string) { func assertUploadedFile(t *testing.T, fs afero.Fs, branch string) {
switch branch { switch branch {
case "main", "master": case "main", "master":
@ -43,6 +62,8 @@ func assertUploadedFile(t *testing.T, fs afero.Fs, branch string) {
} }
type fakeUploadDB struct { type fakeUploadDB struct {
branchesMu sync.Mutex
branchesMap map[string]database.Branch
} }
func (f *fakeUploadDB) GetSiteByDomain(_ context.Context, domain string) (database.Site, error) { func (f *fakeUploadDB) GetSiteByDomain(_ context.Context, domain string) (database.Site, error) {
@ -55,6 +76,42 @@ func (f *fakeUploadDB) GetSiteByDomain(_ context.Context, domain string) (databa
return database.Site{}, sql.ErrNoRows return database.Site{}, sql.ErrNoRows
} }
func (f *fakeUploadDB) AddBranch(ctx context.Context, arg database.AddBranchParams) error {
f.branchesMu.Lock()
defer f.branchesMu.Unlock()
if f.branchesMap == nil {
f.branchesMap = make(map[string]database.Branch)
}
key := arg.Domain + "@" + arg.Branch
_, exists := f.branchesMap[key]
if exists {
return fmt.Errorf("primary key constraint failed")
}
f.branchesMap[key] = database.Branch{
Domain: arg.Domain,
Branch: arg.Branch,
LastUpdate: arg.LastUpdate,
Enable: arg.Enable,
}
return nil
}
func (f *fakeUploadDB) UpdateBranch(ctx context.Context, arg database.UpdateBranchParams) error {
f.branchesMu.Lock()
defer f.branchesMu.Unlock()
if f.branchesMap == nil {
f.branchesMap = make(map[string]database.Branch)
}
item, exists := f.branchesMap[arg.Domain+"@"+arg.Branch]
if !exists {
return sql.ErrNoRows
}
item.LastUpdate = arg.LastUpdate
return nil
}
func TestHandler_Handle(t *testing.T) { func TestHandler_Handle(t *testing.T) {
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
h := New(fs, new(fakeUploadDB)) h := New(fs, new(fakeUploadDB))
@ -95,11 +152,20 @@ func TestHandler_Handle(t *testing.T) {
} }
} }
func TestHandler_extractTarGzUpload(t *testing.T) { func TestHandler_extractTarGzUpload_fakeDB(t *testing.T) {
extractTarGzUploadTest(t, new(fakeUploadDB))
}
func TestHandler_extractTarGzUpload_memoryDB(t *testing.T) {
_, db := initMemoryDB(t)
extractTarGzUploadTest(t, db)
}
func extractTarGzUploadTest(t *testing.T, db uploadQueries) {
for _, branch := range []string{"main", "test", "dev"} { for _, branch := range []string{"main", "test", "dev"} {
t.Run(branch+" branch", func(t *testing.T) { t.Run(branch+" branch", func(t *testing.T) {
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
h := New(fs, new(fakeUploadDB)) h := New(fs, db)
buffer := bytes.NewBuffer(testArchiveTarGz) buffer := bytes.NewBuffer(testArchiveTarGz)
assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", branch)) assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", branch))
@ -107,3 +173,64 @@ func TestHandler_extractTarGzUpload(t *testing.T) {
}) })
} }
} }
func TestHandler_extractTarGzUpload_fakeDB_multiple(t *testing.T) {
extractTarGzUploadMultipleTest(t, new(fakeUploadDB))
}
func TestHandler_extractTarGzUpload_memoryDB_multiple(t *testing.T) {
_, db := initMemoryDB(t)
extractTarGzUploadMultipleTest(t, db)
}
func extractTarGzUploadMultipleTest(t *testing.T, db uploadQueries) {
fs := afero.NewMemMapFs()
h := New(fs, db)
sig := new(atomic.Bool)
wg := new(sync.WaitGroup)
const callCount = 2
wg.Add(callCount)
for range callCount {
go func() {
defer wg.Done()
buffer := newSingleBufferReader(testArchiveTarGz, sig)
assert.NoError(t, h.extractTarGzUpload(buffer, "example.com", "main"))
assertUploadedFile(t, fs, "main")
}()
}
wg.Wait()
assertUploadedFile(t, fs, "main")
}
type singleBufferReader struct {
r io.Reader
signal *atomic.Bool
active bool
size int
}
func (s *singleBufferReader) Read(b []byte) (n int, err error) {
// only check the signal if this reader is not active
if !s.active {
old := s.signal.Swap(true)
if old {
panic("singleBufferReader: peer buffer is already being used")
}
s.active = true
}
n, err = s.r.Read(b)
s.size -= n
if s.size <= 0 {
s.active = false
s.signal.Store(false)
}
return n, err
}
func newSingleBufferReader(buf []byte, b *atomic.Bool) *singleBufferReader {
return &singleBufferReader{r: bytes.NewReader(buf), signal: b, size: len(buf)}
}