diff --git a/cmd/orchid/serve.go b/cmd/orchid/serve.go index eda812a..2caaf21 100644 --- a/cmd/orchid/serve.go +++ b/cmd/orchid/serve.go @@ -2,9 +2,9 @@ package main import ( "context" - "database/sql" "flag" "github.com/1f349/mjwt" + "github.com/1f349/orchid" httpAcme "github.com/1f349/orchid/http-acme" "github.com/1f349/orchid/renewal" "github.com/1f349/orchid/servers" @@ -70,7 +70,7 @@ func normalLoad(conf startUpConfig, wd string) { } // open sqlite database - db, err := sql.Open("sqlite3", filepath.Join(wd, "orchid.db.sqlite")) + db, err := orchid.InitDB(filepath.Join(wd, "orchid.db.sqlite")) if err != nil { log.Fatal("[Orchid] Failed to open database:", err) } diff --git a/database/certificate.sql.go b/database/certificate.sql.go index a217602..de6be34 100644 --- a/database/certificate.sql.go +++ b/database/certificate.sql.go @@ -33,6 +33,22 @@ func (q *Queries) AddCertificate(ctx context.Context, arg AddCertificateParams) return err } +const addTempCertificate = `-- name: AddTempCertificate :exec +INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) +VALUES (?, NULL, 1, ?, ?) +` + +type AddTempCertificateParams struct { + Owner string `json:"owner"` + UpdatedAt time.Time `json:"updated_at"` + TempParent sql.NullInt64 `json:"temp_parent"` +} + +func (q *Queries) AddTempCertificate(ctx context.Context, arg AddTempCertificateParams) error { + _, err := q.db.ExecContext(ctx, addTempCertificate, arg.Owner, arg.UpdatedAt, arg.TempParent) + return err +} + const checkCertOwner = `-- name: CheckCertOwner :one SELECT id, owner FROM certificates diff --git a/database/certificate_domains.sql.go b/database/certificate_domains.sql.go index 678ba0f..5f328be 100644 --- a/database/certificate_domains.sql.go +++ b/database/certificate_domains.sql.go @@ -7,6 +7,7 @@ package database import ( "context" + "strings" ) const addDomains = `-- name: AddDomains :exec @@ -107,15 +108,26 @@ func (q *Queries) SetDomainStateForCert(ctx context.Context, arg SetDomainStateF const updateDomains = `-- name: UpdateDomains :exec UPDATE certificate_domains SET state = ? -WHERE domain IN ? +WHERE domain IN (/*SLICE:domains*/?) ` type UpdateDomainsParams struct { - State int64 `json:"state"` - Domain string `json:"domain"` + State int64 `json:"state"` + Domains []string `json:"domains"` } func (q *Queries) UpdateDomains(ctx context.Context, arg UpdateDomainsParams) error { - _, err := q.db.ExecContext(ctx, updateDomains, arg.State, arg.Domain) + query := updateDomains + var queryParams []interface{} + queryParams = append(queryParams, arg.State) + if len(arg.Domains) > 0 { + for _, v := range arg.Domains { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:domains*/?", strings.Repeat(",?", len(arg.Domains))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:domains*/?", "NULL", 1) + } + _, err := q.db.ExecContext(ctx, query, queryParams...) return err } diff --git a/database/queries/certificate.sql b/database/queries/certificate.sql index 15eb34b..45bb2f2 100644 --- a/database/queries/certificate.sql +++ b/database/queries/certificate.sql @@ -40,6 +40,10 @@ WHERE id = ?; INSERT INTO certificates (owner, dns, not_after, updated_at) VALUES (?, ?, ?, ?); +-- name: AddTempCertificate :exec +INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) +VALUES (?, NULL, 1, ?, ?); + -- name: RemoveCertificate :exec UPDATE certificates SET active = 0 diff --git a/database/queries/certificate_domains.sql b/database/queries/certificate_domains.sql index 08f35f5..af191b5 100644 --- a/database/queries/certificate_domains.sql +++ b/database/queries/certificate_domains.sql @@ -20,4 +20,4 @@ VALUES (?, ?, ?); -- name: UpdateDomains :exec UPDATE certificate_domains SET state = ? -WHERE domain IN (sqlc.slice("domains")); +WHERE domain IN (sqlc.slice(domains)); diff --git a/renewal/service_test.go b/renewal/service_test.go index dea63eb..a9b2d44 100644 --- a/renewal/service_test.go +++ b/renewal/service_test.go @@ -9,6 +9,7 @@ import ( "database/sql" "encoding/pem" "fmt" + "github.com/1f349/orchid" "github.com/1f349/orchid/pebble" "github.com/1f349/orchid/test" "github.com/MrMelon54/certgen" @@ -97,11 +98,14 @@ func setupPebbleSuite(tb testing.TB) (*certgen.CertGen, func()) { } } -func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service { +func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) (*Service, *sql.DB) { wg := &sync.WaitGroup{} dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString()) - db, err := sql.Open("sqlite3", dbFile) + db, err := orchid.InitDB(dbFile) assert.NoError(t, err) + db2, err := sql.Open("sqlite3", dbFile) + assert.NoError(t, err) + log.Println("DB File:", dbFile) certDir, err := os.MkdirTemp("", "orchid-certs") @@ -127,7 +131,7 @@ func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service { assert.NoError(t, err) assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm)) - return service + return service, db2 } func TestPebbleRenewal(t *testing.T) { @@ -151,19 +155,19 @@ func TestPebbleRenewal(t *testing.T) { for _, i := range tests { t.Run(i.name, func(t *testing.T) { //t.Parallel() - service := setupPebbleTest(t, serverTls) + service, db2 := setupPebbleTest(t, serverTls) //goland:noinspection SqlWithoutWhere - _, err := service.db.Exec("DELETE FROM certificate_domains") + _, err := db2.Exec("DELETE FROM certificate_domains") assert.NoError(t, err) - _, err = service.db.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, NULL, NULL)`) + _, err = db2.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, 0, 0)`) assert.NoError(t, err) for _, j := range i.domains { - _, err = service.db.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j) + _, err = db2.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j) assert.NoError(t, err) } - query, err := service.db.Query("SELECT cert_id, domain from certificate_domains") + query, err := db2.Query("SELECT cert_id, domain from certificate_domains") assert.NoError(t, err) for query.Next() { var a uint64 diff --git a/servers/api.go b/servers/api.go index 042532e..05131d2 100644 --- a/servers/api.go +++ b/servers/api.go @@ -163,9 +163,13 @@ func NewApiServer(listen string, db *database.Queries, signer mjwt.Verifier, dom } // run a safe transaction to create the temporary certificate - if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error { + if db.UseTx(req.Context(), func(tx *database.Queries) error { // insert temporary certificate into database - _, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id) + err := tx.AddTempCertificate(req.Context(), database.AddTempCertificateParams{ + Owner: b.Subject, + UpdatedAt: time.Now(), + TempParent: sql.NullInt64{Valid: true, Int64: id}, + }) return err }) != nil { apiError(rw, http.StatusInsufficientStorage, "Database error") diff --git a/servers/certDomainManage.go b/servers/certDomainManage.go index c6b6e9a..a9204cc 100644 --- a/servers/certDomainManage.go +++ b/servers/certDomainManage.go @@ -13,8 +13,8 @@ import ( ) func certDomainManageGET(db *database.Queries, signer mjwt.Verifier) httprouter.Handle { - return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { - rows, err := db.GetDomainStatesForCert(context.Background(), int64(certId)) + return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64) { + rows, err := db.GetDomainStatesForCert(context.Background(), certId) if err != nil { apiError(rw, http.StatusInsufficientStorage, "Database error") return @@ -71,8 +71,8 @@ func certDomainManagePUTandDELETE(db *database.Queries, signer mjwt.Verifier, do } else { // update domains to removed state err := tx.UpdateDomains(req.Context(), database.UpdateDomainsParams{ - State: renewal.DomainStateRemoved, - Domain: d, + State: renewal.DomainStateRemoved, + Domains: d, }) if err != nil { return fmt.Errorf("failed to remove domains from the database")