mirror of
https://github.com/1f349/orchid.git
synced 2025-01-21 06:36:27 +00:00
Continue changes for sqlc
This commit is contained in:
parent
094ac9030a
commit
2b160d4309
@ -2,9 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/orchid"
|
||||
httpAcme "github.com/1f349/orchid/http-acme"
|
||||
"github.com/1f349/orchid/renewal"
|
||||
"github.com/1f349/orchid/servers"
|
||||
@ -70,7 +70,7 @@ func normalLoad(conf startUpConfig, wd string) {
|
||||
}
|
||||
|
||||
// open sqlite database
|
||||
db, err := sql.Open("sqlite3", filepath.Join(wd, "orchid.db.sqlite"))
|
||||
db, err := orchid.InitDB(filepath.Join(wd, "orchid.db.sqlite"))
|
||||
if err != nil {
|
||||
log.Fatal("[Orchid] Failed to open database:", err)
|
||||
}
|
||||
|
@ -33,6 +33,22 @@ func (q *Queries) AddCertificate(ctx context.Context, arg AddCertificateParams)
|
||||
return err
|
||||
}
|
||||
|
||||
const addTempCertificate = `-- name: AddTempCertificate :exec
|
||||
INSERT INTO certificates (owner, dns, active, updated_at, temp_parent)
|
||||
VALUES (?, NULL, 1, ?, ?)
|
||||
`
|
||||
|
||||
type AddTempCertificateParams struct {
|
||||
Owner string `json:"owner"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
TempParent sql.NullInt64 `json:"temp_parent"`
|
||||
}
|
||||
|
||||
func (q *Queries) AddTempCertificate(ctx context.Context, arg AddTempCertificateParams) error {
|
||||
_, err := q.db.ExecContext(ctx, addTempCertificate, arg.Owner, arg.UpdatedAt, arg.TempParent)
|
||||
return err
|
||||
}
|
||||
|
||||
const checkCertOwner = `-- name: CheckCertOwner :one
|
||||
SELECT id, owner
|
||||
FROM certificates
|
||||
|
@ -7,6 +7,7 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const addDomains = `-- name: AddDomains :exec
|
||||
@ -107,15 +108,26 @@ func (q *Queries) SetDomainStateForCert(ctx context.Context, arg SetDomainStateF
|
||||
const updateDomains = `-- name: UpdateDomains :exec
|
||||
UPDATE certificate_domains
|
||||
SET state = ?
|
||||
WHERE domain IN ?
|
||||
WHERE domain IN (/*SLICE:domains*/?)
|
||||
`
|
||||
|
||||
type UpdateDomainsParams struct {
|
||||
State int64 `json:"state"`
|
||||
Domain string `json:"domain"`
|
||||
State int64 `json:"state"`
|
||||
Domains []string `json:"domains"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateDomains(ctx context.Context, arg UpdateDomainsParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateDomains, arg.State, arg.Domain)
|
||||
query := updateDomains
|
||||
var queryParams []interface{}
|
||||
queryParams = append(queryParams, arg.State)
|
||||
if len(arg.Domains) > 0 {
|
||||
for _, v := range arg.Domains {
|
||||
queryParams = append(queryParams, v)
|
||||
}
|
||||
query = strings.Replace(query, "/*SLICE:domains*/?", strings.Repeat(",?", len(arg.Domains))[1:], 1)
|
||||
} else {
|
||||
query = strings.Replace(query, "/*SLICE:domains*/?", "NULL", 1)
|
||||
}
|
||||
_, err := q.db.ExecContext(ctx, query, queryParams...)
|
||||
return err
|
||||
}
|
||||
|
@ -40,6 +40,10 @@ WHERE id = ?;
|
||||
INSERT INTO certificates (owner, dns, not_after, updated_at)
|
||||
VALUES (?, ?, ?, ?);
|
||||
|
||||
-- name: AddTempCertificate :exec
|
||||
INSERT INTO certificates (owner, dns, active, updated_at, temp_parent)
|
||||
VALUES (?, NULL, 1, ?, ?);
|
||||
|
||||
-- name: RemoveCertificate :exec
|
||||
UPDATE certificates
|
||||
SET active = 0
|
||||
|
@ -20,4 +20,4 @@ VALUES (?, ?, ?);
|
||||
-- name: UpdateDomains :exec
|
||||
UPDATE certificate_domains
|
||||
SET state = ?
|
||||
WHERE domain IN (sqlc.slice("domains"));
|
||||
WHERE domain IN (sqlc.slice(domains));
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"github.com/1f349/orchid"
|
||||
"github.com/1f349/orchid/pebble"
|
||||
"github.com/1f349/orchid/test"
|
||||
"github.com/MrMelon54/certgen"
|
||||
@ -97,11 +98,14 @@ func setupPebbleSuite(tb testing.TB) (*certgen.CertGen, func()) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
|
||||
func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) (*Service, *sql.DB) {
|
||||
wg := &sync.WaitGroup{}
|
||||
dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString())
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
db, err := orchid.InitDB(dbFile)
|
||||
assert.NoError(t, err)
|
||||
db2, err := sql.Open("sqlite3", dbFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
log.Println("DB File:", dbFile)
|
||||
|
||||
certDir, err := os.MkdirTemp("", "orchid-certs")
|
||||
@ -127,7 +131,7 @@ func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm))
|
||||
|
||||
return service
|
||||
return service, db2
|
||||
}
|
||||
|
||||
func TestPebbleRenewal(t *testing.T) {
|
||||
@ -151,19 +155,19 @@ func TestPebbleRenewal(t *testing.T) {
|
||||
for _, i := range tests {
|
||||
t.Run(i.name, func(t *testing.T) {
|
||||
//t.Parallel()
|
||||
service := setupPebbleTest(t, serverTls)
|
||||
service, db2 := setupPebbleTest(t, serverTls)
|
||||
//goland:noinspection SqlWithoutWhere
|
||||
_, err := service.db.Exec("DELETE FROM certificate_domains")
|
||||
_, err := db2.Exec("DELETE FROM certificate_domains")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = service.db.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, NULL, NULL)`)
|
||||
_, err = db2.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, 0, 0)`)
|
||||
assert.NoError(t, err)
|
||||
for _, j := range i.domains {
|
||||
_, err = service.db.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j)
|
||||
_, err = db2.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
query, err := service.db.Query("SELECT cert_id, domain from certificate_domains")
|
||||
query, err := db2.Query("SELECT cert_id, domain from certificate_domains")
|
||||
assert.NoError(t, err)
|
||||
for query.Next() {
|
||||
var a uint64
|
||||
|
@ -163,9 +163,13 @@ func NewApiServer(listen string, db *database.Queries, signer mjwt.Verifier, dom
|
||||
}
|
||||
|
||||
// run a safe transaction to create the temporary certificate
|
||||
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error {
|
||||
if db.UseTx(req.Context(), func(tx *database.Queries) error {
|
||||
// insert temporary certificate into database
|
||||
_, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id)
|
||||
err := tx.AddTempCertificate(req.Context(), database.AddTempCertificateParams{
|
||||
Owner: b.Subject,
|
||||
UpdatedAt: time.Now(),
|
||||
TempParent: sql.NullInt64{Valid: true, Int64: id},
|
||||
})
|
||||
return err
|
||||
}) != nil {
|
||||
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
func certDomainManageGET(db *database.Queries, signer mjwt.Verifier) httprouter.Handle {
|
||||
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
|
||||
rows, err := db.GetDomainStatesForCert(context.Background(), int64(certId))
|
||||
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64) {
|
||||
rows, err := db.GetDomainStatesForCert(context.Background(), certId)
|
||||
if err != nil {
|
||||
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||
return
|
||||
@ -71,8 +71,8 @@ func certDomainManagePUTandDELETE(db *database.Queries, signer mjwt.Verifier, do
|
||||
} else {
|
||||
// update domains to removed state
|
||||
err := tx.UpdateDomains(req.Context(), database.UpdateDomainsParams{
|
||||
State: renewal.DomainStateRemoved,
|
||||
Domain: d,
|
||||
State: renewal.DomainStateRemoved,
|
||||
Domains: d,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove domains from the database")
|
||||
|
Loading…
Reference in New Issue
Block a user