Continue changes for sqlc

This commit is contained in:
Melon 2024-03-09 00:55:06 +00:00
parent 094ac9030a
commit 2b160d4309
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
8 changed files with 61 additions and 21 deletions

View File

@ -2,9 +2,9 @@ package main
import (
"context"
"database/sql"
"flag"
"github.com/1f349/mjwt"
"github.com/1f349/orchid"
httpAcme "github.com/1f349/orchid/http-acme"
"github.com/1f349/orchid/renewal"
"github.com/1f349/orchid/servers"
@ -70,7 +70,7 @@ func normalLoad(conf startUpConfig, wd string) {
}
// open sqlite database
db, err := sql.Open("sqlite3", filepath.Join(wd, "orchid.db.sqlite"))
db, err := orchid.InitDB(filepath.Join(wd, "orchid.db.sqlite"))
if err != nil {
log.Fatal("[Orchid] Failed to open database:", err)
}

View File

@ -33,6 +33,22 @@ func (q *Queries) AddCertificate(ctx context.Context, arg AddCertificateParams)
return err
}
const addTempCertificate = `-- name: AddTempCertificate :exec
INSERT INTO certificates (owner, dns, active, updated_at, temp_parent)
VALUES (?, NULL, 1, ?, ?)
`
type AddTempCertificateParams struct {
Owner string `json:"owner"`
UpdatedAt time.Time `json:"updated_at"`
TempParent sql.NullInt64 `json:"temp_parent"`
}
func (q *Queries) AddTempCertificate(ctx context.Context, arg AddTempCertificateParams) error {
_, err := q.db.ExecContext(ctx, addTempCertificate, arg.Owner, arg.UpdatedAt, arg.TempParent)
return err
}
const checkCertOwner = `-- name: CheckCertOwner :one
SELECT id, owner
FROM certificates

View File

@ -7,6 +7,7 @@ package database
import (
"context"
"strings"
)
const addDomains = `-- name: AddDomains :exec
@ -107,15 +108,26 @@ func (q *Queries) SetDomainStateForCert(ctx context.Context, arg SetDomainStateF
const updateDomains = `-- name: UpdateDomains :exec
UPDATE certificate_domains
SET state = ?
WHERE domain IN ?
WHERE domain IN (/*SLICE:domains*/?)
`
type UpdateDomainsParams struct {
State int64 `json:"state"`
Domain string `json:"domain"`
State int64 `json:"state"`
Domains []string `json:"domains"`
}
func (q *Queries) UpdateDomains(ctx context.Context, arg UpdateDomainsParams) error {
_, err := q.db.ExecContext(ctx, updateDomains, arg.State, arg.Domain)
query := updateDomains
var queryParams []interface{}
queryParams = append(queryParams, arg.State)
if len(arg.Domains) > 0 {
for _, v := range arg.Domains {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:domains*/?", strings.Repeat(",?", len(arg.Domains))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:domains*/?", "NULL", 1)
}
_, err := q.db.ExecContext(ctx, query, queryParams...)
return err
}

View File

@ -40,6 +40,10 @@ WHERE id = ?;
INSERT INTO certificates (owner, dns, not_after, updated_at)
VALUES (?, ?, ?, ?);
-- name: AddTempCertificate :exec
INSERT INTO certificates (owner, dns, active, updated_at, temp_parent)
VALUES (?, NULL, 1, ?, ?);
-- name: RemoveCertificate :exec
UPDATE certificates
SET active = 0

View File

@ -20,4 +20,4 @@ VALUES (?, ?, ?);
-- name: UpdateDomains :exec
UPDATE certificate_domains
SET state = ?
WHERE domain IN (sqlc.slice("domains"));
WHERE domain IN (sqlc.slice(domains));

View File

@ -9,6 +9,7 @@ import (
"database/sql"
"encoding/pem"
"fmt"
"github.com/1f349/orchid"
"github.com/1f349/orchid/pebble"
"github.com/1f349/orchid/test"
"github.com/MrMelon54/certgen"
@ -97,11 +98,14 @@ func setupPebbleSuite(tb testing.TB) (*certgen.CertGen, func()) {
}
}
func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) (*Service, *sql.DB) {
wg := &sync.WaitGroup{}
dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString())
db, err := sql.Open("sqlite3", dbFile)
db, err := orchid.InitDB(dbFile)
assert.NoError(t, err)
db2, err := sql.Open("sqlite3", dbFile)
assert.NoError(t, err)
log.Println("DB File:", dbFile)
certDir, err := os.MkdirTemp("", "orchid-certs")
@ -127,7 +131,7 @@ func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
assert.NoError(t, err)
assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm))
return service
return service, db2
}
func TestPebbleRenewal(t *testing.T) {
@ -151,19 +155,19 @@ func TestPebbleRenewal(t *testing.T) {
for _, i := range tests {
t.Run(i.name, func(t *testing.T) {
//t.Parallel()
service := setupPebbleTest(t, serverTls)
service, db2 := setupPebbleTest(t, serverTls)
//goland:noinspection SqlWithoutWhere
_, err := service.db.Exec("DELETE FROM certificate_domains")
_, err := db2.Exec("DELETE FROM certificate_domains")
assert.NoError(t, err)
_, err = service.db.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, NULL, NULL)`)
_, err = db2.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, 0, 0)`)
assert.NoError(t, err)
for _, j := range i.domains {
_, err = service.db.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j)
_, err = db2.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j)
assert.NoError(t, err)
}
query, err := service.db.Query("SELECT cert_id, domain from certificate_domains")
query, err := db2.Query("SELECT cert_id, domain from certificate_domains")
assert.NoError(t, err)
for query.Next() {
var a uint64

View File

@ -163,9 +163,13 @@ func NewApiServer(listen string, db *database.Queries, signer mjwt.Verifier, dom
}
// run a safe transaction to create the temporary certificate
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error {
if db.UseTx(req.Context(), func(tx *database.Queries) error {
// insert temporary certificate into database
_, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id)
err := tx.AddTempCertificate(req.Context(), database.AddTempCertificateParams{
Owner: b.Subject,
UpdatedAt: time.Now(),
TempParent: sql.NullInt64{Valid: true, Int64: id},
})
return err
}) != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")

View File

@ -13,8 +13,8 @@ import (
)
func certDomainManageGET(db *database.Queries, signer mjwt.Verifier) httprouter.Handle {
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
rows, err := db.GetDomainStatesForCert(context.Background(), int64(certId))
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64) {
rows, err := db.GetDomainStatesForCert(context.Background(), certId)
if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
@ -71,8 +71,8 @@ func certDomainManagePUTandDELETE(db *database.Queries, signer mjwt.Verifier, do
} else {
// update domains to removed state
err := tx.UpdateDomains(req.Context(), database.UpdateDomainsParams{
State: renewal.DomainStateRemoved,
Domain: d,
State: renewal.DomainStateRemoved,
Domains: d,
})
if err != nil {
return fmt.Errorf("failed to remove domains from the database")