Continue changes for sqlc

This commit is contained in:
Melon 2024-03-09 00:55:06 +00:00
parent 094ac9030a
commit 2b160d4309
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
8 changed files with 61 additions and 21 deletions

View File

@ -2,9 +2,9 @@ package main
import ( import (
"context" "context"
"database/sql"
"flag" "flag"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/orchid"
httpAcme "github.com/1f349/orchid/http-acme" httpAcme "github.com/1f349/orchid/http-acme"
"github.com/1f349/orchid/renewal" "github.com/1f349/orchid/renewal"
"github.com/1f349/orchid/servers" "github.com/1f349/orchid/servers"
@ -70,7 +70,7 @@ func normalLoad(conf startUpConfig, wd string) {
} }
// open sqlite database // open sqlite database
db, err := sql.Open("sqlite3", filepath.Join(wd, "orchid.db.sqlite")) db, err := orchid.InitDB(filepath.Join(wd, "orchid.db.sqlite"))
if err != nil { if err != nil {
log.Fatal("[Orchid] Failed to open database:", err) log.Fatal("[Orchid] Failed to open database:", err)
} }

View File

@ -33,6 +33,22 @@ func (q *Queries) AddCertificate(ctx context.Context, arg AddCertificateParams)
return err return err
} }
const addTempCertificate = `-- name: AddTempCertificate :exec
INSERT INTO certificates (owner, dns, active, updated_at, temp_parent)
VALUES (?, NULL, 1, ?, ?)
`
type AddTempCertificateParams struct {
Owner string `json:"owner"`
UpdatedAt time.Time `json:"updated_at"`
TempParent sql.NullInt64 `json:"temp_parent"`
}
func (q *Queries) AddTempCertificate(ctx context.Context, arg AddTempCertificateParams) error {
_, err := q.db.ExecContext(ctx, addTempCertificate, arg.Owner, arg.UpdatedAt, arg.TempParent)
return err
}
const checkCertOwner = `-- name: CheckCertOwner :one const checkCertOwner = `-- name: CheckCertOwner :one
SELECT id, owner SELECT id, owner
FROM certificates FROM certificates

View File

@ -7,6 +7,7 @@ package database
import ( import (
"context" "context"
"strings"
) )
const addDomains = `-- name: AddDomains :exec const addDomains = `-- name: AddDomains :exec
@ -107,15 +108,26 @@ func (q *Queries) SetDomainStateForCert(ctx context.Context, arg SetDomainStateF
const updateDomains = `-- name: UpdateDomains :exec const updateDomains = `-- name: UpdateDomains :exec
UPDATE certificate_domains UPDATE certificate_domains
SET state = ? SET state = ?
WHERE domain IN ? WHERE domain IN (/*SLICE:domains*/?)
` `
type UpdateDomainsParams struct { type UpdateDomainsParams struct {
State int64 `json:"state"` State int64 `json:"state"`
Domain string `json:"domain"` Domains []string `json:"domains"`
} }
func (q *Queries) UpdateDomains(ctx context.Context, arg UpdateDomainsParams) error { func (q *Queries) UpdateDomains(ctx context.Context, arg UpdateDomainsParams) error {
_, err := q.db.ExecContext(ctx, updateDomains, arg.State, arg.Domain) query := updateDomains
var queryParams []interface{}
queryParams = append(queryParams, arg.State)
if len(arg.Domains) > 0 {
for _, v := range arg.Domains {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:domains*/?", strings.Repeat(",?", len(arg.Domains))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:domains*/?", "NULL", 1)
}
_, err := q.db.ExecContext(ctx, query, queryParams...)
return err return err
} }

View File

@ -40,6 +40,10 @@ WHERE id = ?;
INSERT INTO certificates (owner, dns, not_after, updated_at) INSERT INTO certificates (owner, dns, not_after, updated_at)
VALUES (?, ?, ?, ?); VALUES (?, ?, ?, ?);
-- name: AddTempCertificate :exec
INSERT INTO certificates (owner, dns, active, updated_at, temp_parent)
VALUES (?, NULL, 1, ?, ?);
-- name: RemoveCertificate :exec -- name: RemoveCertificate :exec
UPDATE certificates UPDATE certificates
SET active = 0 SET active = 0

View File

@ -20,4 +20,4 @@ VALUES (?, ?, ?);
-- name: UpdateDomains :exec -- name: UpdateDomains :exec
UPDATE certificate_domains UPDATE certificate_domains
SET state = ? SET state = ?
WHERE domain IN (sqlc.slice("domains")); WHERE domain IN (sqlc.slice(domains));

View File

@ -9,6 +9,7 @@ import (
"database/sql" "database/sql"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"github.com/1f349/orchid"
"github.com/1f349/orchid/pebble" "github.com/1f349/orchid/pebble"
"github.com/1f349/orchid/test" "github.com/1f349/orchid/test"
"github.com/MrMelon54/certgen" "github.com/MrMelon54/certgen"
@ -97,11 +98,14 @@ func setupPebbleSuite(tb testing.TB) (*certgen.CertGen, func()) {
} }
} }
func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service { func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) (*Service, *sql.DB) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString()) dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString())
db, err := sql.Open("sqlite3", dbFile) db, err := orchid.InitDB(dbFile)
assert.NoError(t, err) assert.NoError(t, err)
db2, err := sql.Open("sqlite3", dbFile)
assert.NoError(t, err)
log.Println("DB File:", dbFile) log.Println("DB File:", dbFile)
certDir, err := os.MkdirTemp("", "orchid-certs") certDir, err := os.MkdirTemp("", "orchid-certs")
@ -127,7 +131,7 @@ func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm)) assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm))
return service return service, db2
} }
func TestPebbleRenewal(t *testing.T) { func TestPebbleRenewal(t *testing.T) {
@ -151,19 +155,19 @@ func TestPebbleRenewal(t *testing.T) {
for _, i := range tests { for _, i := range tests {
t.Run(i.name, func(t *testing.T) { t.Run(i.name, func(t *testing.T) {
//t.Parallel() //t.Parallel()
service := setupPebbleTest(t, serverTls) service, db2 := setupPebbleTest(t, serverTls)
//goland:noinspection SqlWithoutWhere //goland:noinspection SqlWithoutWhere
_, err := service.db.Exec("DELETE FROM certificate_domains") _, err := db2.Exec("DELETE FROM certificate_domains")
assert.NoError(t, err) assert.NoError(t, err)
_, err = service.db.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, NULL, NULL)`) _, err = db2.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, 0, 0)`)
assert.NoError(t, err) assert.NoError(t, err)
for _, j := range i.domains { for _, j := range i.domains {
_, err = service.db.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j) _, err = db2.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j)
assert.NoError(t, err) assert.NoError(t, err)
} }
query, err := service.db.Query("SELECT cert_id, domain from certificate_domains") query, err := db2.Query("SELECT cert_id, domain from certificate_domains")
assert.NoError(t, err) assert.NoError(t, err)
for query.Next() { for query.Next() {
var a uint64 var a uint64

View File

@ -163,9 +163,13 @@ func NewApiServer(listen string, db *database.Queries, signer mjwt.Verifier, dom
} }
// run a safe transaction to create the temporary certificate // run a safe transaction to create the temporary certificate
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error { if db.UseTx(req.Context(), func(tx *database.Queries) error {
// insert temporary certificate into database // insert temporary certificate into database
_, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id) err := tx.AddTempCertificate(req.Context(), database.AddTempCertificateParams{
Owner: b.Subject,
UpdatedAt: time.Now(),
TempParent: sql.NullInt64{Valid: true, Int64: id},
})
return err return err
}) != nil { }) != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error") apiError(rw, http.StatusInsufficientStorage, "Database error")

View File

@ -13,8 +13,8 @@ import (
) )
func certDomainManageGET(db *database.Queries, signer mjwt.Verifier) httprouter.Handle { func certDomainManageGET(db *database.Queries, signer mjwt.Verifier) httprouter.Handle {
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64) {
rows, err := db.GetDomainStatesForCert(context.Background(), int64(certId)) rows, err := db.GetDomainStatesForCert(context.Background(), certId)
if err != nil { if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error") apiError(rw, http.StatusInsufficientStorage, "Database error")
return return
@ -72,7 +72,7 @@ func certDomainManagePUTandDELETE(db *database.Queries, signer mjwt.Verifier, do
// update domains to removed state // update domains to removed state
err := tx.UpdateDomains(req.Context(), database.UpdateDomainsParams{ err := tx.UpdateDomains(req.Context(), database.UpdateDomainsParams{
State: renewal.DomainStateRemoved, State: renewal.DomainStateRemoved,
Domain: d, Domains: d,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to remove domains from the database") return fmt.Errorf("failed to remove domains from the database")