Start updating to support sqlc and migrations

This commit is contained in:
Melon 2024-03-09 00:31:52 +00:00
parent 9b3c801ebf
commit 094ac9030a
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
21 changed files with 672 additions and 198 deletions

192
database/certificate.sql.go Normal file
View File

@ -0,0 +1,192 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: certificate.sql
package database
import (
"context"
"database/sql"
"time"
)
const addCertificate = `-- name: AddCertificate :exec
INSERT INTO certificates (owner, dns, not_after, updated_at)
VALUES (?, ?, ?, ?)
`
type AddCertificateParams struct {
Owner string `json:"owner"`
Dns sql.NullInt64 `json:"dns"`
NotAfter time.Time `json:"not_after"`
UpdatedAt time.Time `json:"updated_at"`
}
func (q *Queries) AddCertificate(ctx context.Context, arg AddCertificateParams) error {
_, err := q.db.ExecContext(ctx, addCertificate,
arg.Owner,
arg.Dns,
arg.NotAfter,
arg.UpdatedAt,
)
return err
}
const checkCertOwner = `-- name: CheckCertOwner :one
SELECT id, owner
FROM certificates
WHERE active = 1
and id = ?
`
type CheckCertOwnerRow struct {
ID int64 `json:"id"`
Owner string `json:"owner"`
}
func (q *Queries) CheckCertOwner(ctx context.Context, id int64) (CheckCertOwnerRow, error) {
row := q.db.QueryRowContext(ctx, checkCertOwner, id)
var i CheckCertOwnerRow
err := row.Scan(&i.ID, &i.Owner)
return i, err
}
const findNextCert = `-- name: FindNextCert :one
SELECT cert.id, cert.not_after, dns_acme.type, dns_acme.token, cert.temp_parent
FROM certificates AS cert
LEFT OUTER JOIN dns_acme ON cert.dns = dns_acme.id
WHERE cert.active = 1
AND (cert.auto_renew = 1 OR cert.not_after IS NULL)
AND cert.renewing = 0
AND cert.renew_failed = 0
AND (cert.not_after IS NULL OR DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
ORDER BY cert.temp_parent, cert.not_after DESC NULLS FIRST
LIMIT 1
`
type FindNextCertRow struct {
ID int64 `json:"id"`
NotAfter time.Time `json:"not_after"`
Type sql.NullString `json:"type"`
Token sql.NullString `json:"token"`
TempParent sql.NullInt64 `json:"temp_parent"`
}
func (q *Queries) FindNextCert(ctx context.Context) (FindNextCertRow, error) {
row := q.db.QueryRowContext(ctx, findNextCert)
var i FindNextCertRow
err := row.Scan(
&i.ID,
&i.NotAfter,
&i.Type,
&i.Token,
&i.TempParent,
)
return i, err
}
const findOwnedCerts = `-- name: FindOwnedCerts :many
SELECT cert.id,
cert.auto_renew,
cert.active,
cert.renewing,
cert.renew_failed,
cert.not_after,
cert.updated_at,
certificate_domains.domain
FROM certificates AS cert
INNER JOIN certificate_domains ON cert.id = certificate_domains.cert_id
`
type FindOwnedCertsRow struct {
ID int64 `json:"id"`
AutoRenew bool `json:"auto_renew"`
Active bool `json:"active"`
Renewing bool `json:"renewing"`
RenewFailed bool `json:"renew_failed"`
NotAfter time.Time `json:"not_after"`
UpdatedAt time.Time `json:"updated_at"`
Domain string `json:"domain"`
}
func (q *Queries) FindOwnedCerts(ctx context.Context) ([]FindOwnedCertsRow, error) {
rows, err := q.db.QueryContext(ctx, findOwnedCerts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []FindOwnedCertsRow
for rows.Next() {
var i FindOwnedCertsRow
if err := rows.Scan(
&i.ID,
&i.AutoRenew,
&i.Active,
&i.Renewing,
&i.RenewFailed,
&i.NotAfter,
&i.UpdatedAt,
&i.Domain,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const removeCertificate = `-- name: RemoveCertificate :exec
UPDATE certificates
SET active = 0
WHERE id = ?
`
func (q *Queries) RemoveCertificate(ctx context.Context, id int64) error {
_, err := q.db.ExecContext(ctx, removeCertificate, id)
return err
}
const updateCertAfterRenewal = `-- name: UpdateCertAfterRenewal :exec
UPDATE certificates
SET renewing = 0,
renew_failed=0,
not_after=?,
updated_at=?
WHERE id = ?
`
type UpdateCertAfterRenewalParams struct {
NotAfter time.Time `json:"not_after"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
}
func (q *Queries) UpdateCertAfterRenewal(ctx context.Context, arg UpdateCertAfterRenewalParams) error {
_, err := q.db.ExecContext(ctx, updateCertAfterRenewal, arg.NotAfter, arg.UpdatedAt, arg.ID)
return err
}
const updateRenewingState = `-- name: UpdateRenewingState :exec
UPDATE certificates
SET renewing = ?,
renew_failed = ?
WHERE id = ?
`
type UpdateRenewingStateParams struct {
Renewing bool `json:"renewing"`
RenewFailed bool `json:"renew_failed"`
ID int64 `json:"id"`
}
func (q *Queries) UpdateRenewingState(ctx context.Context, arg UpdateRenewingStateParams) error {
_, err := q.db.ExecContext(ctx, updateRenewingState, arg.Renewing, arg.RenewFailed, arg.ID)
return err
}

View File

@ -0,0 +1,121 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: certificate_domains.sql
package database
import (
"context"
)
const addDomains = `-- name: AddDomains :exec
INSERT INTO certificate_domains (cert_id, domain, state)
VALUES (?, ?, ?)
`
type AddDomainsParams struct {
CertID int64 `json:"cert_id"`
Domain string `json:"domain"`
State int64 `json:"state"`
}
func (q *Queries) AddDomains(ctx context.Context, arg AddDomainsParams) error {
_, err := q.db.ExecContext(ctx, addDomains, arg.CertID, arg.Domain, arg.State)
return err
}
const getDomainStatesForCert = `-- name: GetDomainStatesForCert :many
SELECT domain, state
FROM certificate_domains
WHERE cert_id = ?
`
type GetDomainStatesForCertRow struct {
Domain string `json:"domain"`
State int64 `json:"state"`
}
func (q *Queries) GetDomainStatesForCert(ctx context.Context, certID int64) ([]GetDomainStatesForCertRow, error) {
rows, err := q.db.QueryContext(ctx, getDomainStatesForCert, certID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetDomainStatesForCertRow
for rows.Next() {
var i GetDomainStatesForCertRow
if err := rows.Scan(&i.Domain, &i.State); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getDomainsForCertificate = `-- name: GetDomainsForCertificate :many
SELECT domain
FROM certificate_domains
WHERE cert_id = ?
`
func (q *Queries) GetDomainsForCertificate(ctx context.Context, certID int64) ([]string, error) {
rows, err := q.db.QueryContext(ctx, getDomainsForCertificate, certID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var domain string
if err := rows.Scan(&domain); err != nil {
return nil, err
}
items = append(items, domain)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const setDomainStateForCert = `-- name: SetDomainStateForCert :exec
UPDATE certificate_domains
SET state = ?
WHERE cert_id = ?
`
type SetDomainStateForCertParams struct {
State int64 `json:"state"`
CertID int64 `json:"cert_id"`
}
func (q *Queries) SetDomainStateForCert(ctx context.Context, arg SetDomainStateForCertParams) error {
_, err := q.db.ExecContext(ctx, setDomainStateForCert, arg.State, arg.CertID)
return err
}
const updateDomains = `-- name: UpdateDomains :exec
UPDATE certificate_domains
SET state = ?
WHERE domain IN ?
`
type UpdateDomainsParams struct {
State int64 `json:"state"`
Domain string `json:"domain"`
}
func (q *Queries) UpdateDomains(ctx context.Context, arg UpdateDomainsParams) error {
_, err := q.db.ExecContext(ctx, updateDomains, arg.State, arg.Domain)
return err
}

31
database/db.go Normal file
View File

@ -0,0 +1,31 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
package database
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@ -0,0 +1,33 @@
CREATE TABLE IF NOT EXISTS certificates
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner VARCHAR NOT NULL,
dns INTEGER,
auto_renew BOOLEAN NOT NULL DEFAULT 0,
active BOOLEAN NOT NULL DEFAULT 0,
renewing BOOLEAN NOT NULL DEFAULT 0,
renew_failed BOOLEAN NOT NULL DEFAULT 0,
not_after DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
temp_parent INTEGER,
FOREIGN KEY (dns) REFERENCES dns_acme (id),
FOREIGN KEY (temp_parent) REFERENCES certificates (id)
);
CREATE TABLE IF NOT EXISTS certificate_domains
(
domain_id INTEGER PRIMARY KEY AUTOINCREMENT,
cert_id INTEGER NOT NULL,
domain VARCHAR NOT NULL,
state INTEGER NOT NULL DEFAULT 1,
UNIQUE (cert_id, domain),
FOREIGN KEY (cert_id) REFERENCES certificates (id)
);
CREATE TABLE IF NOT EXISTS dns_acme
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
type VARCHAR NOT NULL,
email VARCHAR NOT NULL,
token VARCHAR NOT NULL
);

37
database/models.go Normal file
View File

@ -0,0 +1,37 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
package database
import (
"database/sql"
"time"
)
type Certificate struct {
ID int64 `json:"id"`
Owner string `json:"owner"`
Dns sql.NullInt64 `json:"dns"`
AutoRenew bool `json:"auto_renew"`
Active bool `json:"active"`
Renewing bool `json:"renewing"`
RenewFailed bool `json:"renew_failed"`
NotAfter time.Time `json:"not_after"`
UpdatedAt time.Time `json:"updated_at"`
TempParent sql.NullInt64 `json:"temp_parent"`
}
type CertificateDomain struct {
DomainID int64 `json:"domain_id"`
CertID int64 `json:"cert_id"`
Domain string `json:"domain"`
State int64 `json:"state"`
}
type DnsAcme struct {
ID int64 `json:"id"`
Type string `json:"type"`
Email string `json:"email"`
Token string `json:"token"`
}

View File

@ -0,0 +1,52 @@
-- name: FindNextCert :one
SELECT cert.id, cert.not_after, dns_acme.type, dns_acme.token, cert.temp_parent
FROM certificates AS cert
LEFT OUTER JOIN dns_acme ON cert.dns = dns_acme.id
WHERE cert.active = 1
AND (cert.auto_renew = 1 OR cert.not_after IS NULL)
AND cert.renewing = 0
AND cert.renew_failed = 0
AND (cert.not_after IS NULL OR DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
ORDER BY cert.temp_parent, cert.not_after DESC NULLS FIRST
LIMIT 1;
-- name: FindOwnedCerts :many
SELECT cert.id,
cert.auto_renew,
cert.active,
cert.renewing,
cert.renew_failed,
cert.not_after,
cert.updated_at,
certificate_domains.domain
FROM certificates AS cert
INNER JOIN certificate_domains ON cert.id = certificate_domains.cert_id;
-- name: UpdateRenewingState :exec
UPDATE certificates
SET renewing = ?,
renew_failed = ?
WHERE id = ?;
-- name: UpdateCertAfterRenewal :exec
UPDATE certificates
SET renewing = 0,
renew_failed=0,
not_after=?,
updated_at=?
WHERE id = ?;
-- name: AddCertificate :exec
INSERT INTO certificates (owner, dns, not_after, updated_at)
VALUES (?, ?, ?, ?);
-- name: RemoveCertificate :exec
UPDATE certificates
SET active = 0
WHERE id = ?;
-- name: CheckCertOwner :one
SELECT id, owner
FROM certificates
WHERE active = 1
and id = ?;

View File

@ -0,0 +1,23 @@
-- name: GetDomainsForCertificate :many
SELECT domain
FROM certificate_domains
WHERE cert_id = ?;
-- name: GetDomainStatesForCert :many
SELECT domain, state
FROM certificate_domains
WHERE cert_id = ?;
-- name: SetDomainStateForCert :exec
UPDATE certificate_domains
SET state = ?
WHERE cert_id = ?;
-- name: AddDomains :exec
INSERT INTO certificate_domains (cert_id, domain, state)
VALUES (?, ?, ?);
-- name: UpdateDomains :exec
UPDATE certificate_domains
SET state = ?
WHERE domain IN (sqlc.slice("domains"));

23
database/tx.go Normal file
View File

@ -0,0 +1,23 @@
package database
import (
"context"
"database/sql"
)
func (q *Queries) UseTx(ctx context.Context, cb func(tx *Queries) error) error {
sqlDB, ok := q.db.(*sql.DB)
if !ok {
panic("cannot open transaction without sql.DB")
}
tx, err := sqlDB.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
err = cb(q.WithTx(tx))
if err != nil {
return err
}
return tx.Commit()
}

8
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/MrMelon54/certgen v0.0.1 github.com/MrMelon54/certgen v0.0.1
github.com/MrMelon54/exit-reload v0.0.1 github.com/MrMelon54/exit-reload v0.0.1
github.com/go-acme/lego/v4 v4.14.2 github.com/go-acme/lego/v4 v4.14.2
github.com/golang-migrate/migrate/v4 v4.17.1
github.com/google/subcommands v1.2.0 github.com/google/subcommands v1.2.0
github.com/google/uuid v1.4.0 github.com/google/uuid v1.4.0
github.com/julienschmidt/httprouter v1.3.0 github.com/julienschmidt/httprouter v1.3.0
@ -25,6 +26,8 @@ require (
github.com/go-jose/go-jose/v3 v3.0.3 // indirect github.com/go-jose/go-jose/v3 v3.0.3 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
@ -32,9 +35,10 @@ require (
github.com/nrdcg/namesilo v0.2.1 // indirect github.com/nrdcg/namesilo v0.2.1 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.19.0 // indirect go.uber.org/atomic v1.7.0 // indirect
golang.org/x/crypto v0.20.0 // indirect
golang.org/x/mod v0.14.0 // indirect golang.org/x/mod v0.14.0 // indirect
golang.org/x/net v0.19.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0 // indirect golang.org/x/sys v0.17.0 // indirect
golang.org/x/term v0.17.0 // indirect golang.org/x/term v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect

18
go.sum
View File

@ -25,6 +25,8 @@ github.com/go-jose/go-jose/v3 v3.0.3 h1:fFKWeig/irsp7XD2zBxvnmA/XaRWp5V3CBsZXJF7
github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4=
github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@ -35,6 +37,11 @@ github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
@ -45,6 +52,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
@ -74,10 +83,13 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.20.0 h1:jmAMJJZXr5KiCw05dfYK9QnqaqKLYXijU23lsEdcQqg=
golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
@ -87,8 +99,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

38
initdb.go Normal file
View File

@ -0,0 +1,38 @@
package orchid
import (
"database/sql"
"embed"
"errors"
"github.com/1f349/orchid/database"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
)
//go:embed database/migrations/*.sql
var migrations embed.FS
func InitDB(p string) (*database.Queries, error) {
migDrv, err := iofs.New(migrations, "database/migrations")
if err != nil {
return nil, err
}
dbOpen, err := sql.Open("sqlite3", p)
if err != nil {
return nil, err
}
dbDrv, err := sqlite3.WithInstance(dbOpen, &sqlite3.Config{})
if err != nil {
return nil, err
}
mig, err := migrate.NewWithInstance("iofs", migDrv, "sqlite3", dbDrv)
if err != nil {
return nil, err
}
err = mig.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return nil, err
}
return database.New(dbOpen), nil
}

View File

@ -1,33 +0,0 @@
CREATE TABLE IF NOT EXISTS certificates
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner VARCHAR,
dns INTEGER,
auto_renew INTEGER DEFAULT 0,
active INTEGER DEFAULT 0,
renewing INTEGER DEFAULT 0,
renew_failed INTEGER DEFAULT 0,
not_after DATETIME,
updated_at DATETIME,
temp_parent INTEGER DEFAULT 0,
FOREIGN KEY (dns) REFERENCES dns_acme (id),
FOREIGN KEY (temp_parent) REFERENCES certificates (id)
);
CREATE TABLE IF NOT EXISTS certificate_domains
(
domain_id INTEGER PRIMARY KEY AUTOINCREMENT,
cert_id INTEGER,
domain VARCHAR,
state INTEGER DEFAULT 1,
UNIQUE (cert_id, domain),
FOREIGN KEY (cert_id) REFERENCES certificates (id)
);
CREATE TABLE IF NOT EXISTS dns_acme
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
type VARCHAR,
email VARCHAR,
token VARCHAR
);

View File

@ -1,9 +0,0 @@
select cert.id, cert.not_after, dns_acme.type, dns_acme.token, cert.temp_parent
from certificates as cert
left outer join dns_acme on cert.dns = dns_acme.id
where cert.active = 1
and (cert.auto_renew = 1 or cert.not_after IS NULL)
and cert.renewing = 0
and cert.renew_failed = 0
and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
order by cert.temp_parent, cert.not_after DESC NULLS FIRST

View File

@ -2,16 +2,17 @@ package renewal
import ( import (
"database/sql" "database/sql"
"time"
) )
// Contains local types for the renewal service // Contains local types for the renewal service
type localCertData struct { type localCertData struct {
id uint64 id int64
dns struct { dns struct {
name sql.NullString name sql.NullString
token sql.NullString token sql.NullString
} }
notAfter sql.NullTime notAfter time.Time
domains []string domains []string
tempParent uint64 tempParent sql.NullInt64
} }

View File

@ -2,6 +2,7 @@ package renewal
import ( import (
"bytes" "bytes"
"context"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@ -10,6 +11,7 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"github.com/1f349/orchid/database"
"github.com/1f349/orchid/pebble" "github.com/1f349/orchid/pebble"
"github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge"
@ -18,7 +20,6 @@ import (
"github.com/go-acme/lego/v4/providers/dns/duckdns" "github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/namesilo" "github.com/go-acme/lego/v4/providers/dns/namesilo"
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
"io"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
@ -28,13 +29,7 @@ import (
"time" "time"
) )
var ( var ErrUnsupportedDNSProvider = errors.New("unsupported DNS provider")
ErrUnsupportedDNSProvider = errors.New("unsupported DNS provider")
//go:embed find-next-cert.sql
findNextCertSql string
//go:embed create-tables.sql
createTableCertificates string
)
const ( const (
DomainStateNormal = 0 DomainStateNormal = 0
@ -57,7 +52,7 @@ var testDnsOptions interface {
// `_acme-challenges` TXT records are updated to validate the ownership of the // `_acme-challenges` TXT records are updated to validate the ownership of the
// specified domains. // specified domains.
type Service struct { type Service struct {
db *sql.DB db *database.Queries
httpAcme challenge.Provider httpAcme challenge.Provider
certTicker *time.Ticker certTicker *time.Ticker
certDone chan struct{} certDone chan struct{}
@ -73,7 +68,7 @@ type Service struct {
} }
// NewService creates a new certificate renewal service. // NewService creates a new certificate renewal service.
func NewService(wg *sync.WaitGroup, db *sql.DB, httpAcme challenge.Provider, leConfig LetsEncryptConfig, certDir, keyDir string) (*Service, error) { func NewService(wg *sync.WaitGroup, db *database.Queries, httpAcme challenge.Provider, leConfig LetsEncryptConfig, certDir, keyDir string) (*Service, error) {
s := &Service{ s := &Service{
db: db, db: db,
httpAcme: httpAcme, httpAcme: httpAcme,
@ -104,12 +99,6 @@ func NewService(wg *sync.WaitGroup, db *sql.DB, httpAcme challenge.Provider, leC
return nil, fmt.Errorf("failed to resolve LetsEncrypt account private key: %w", err) return nil, fmt.Errorf("failed to resolve LetsEncrypt account private key: %w", err)
} }
// init domains table
_, err = s.db.Exec(createTableCertificates)
if err != nil {
return nil, fmt.Errorf("failed to create certificates table: %w", err)
}
// resolve CA information // resolve CA information
s.resolveCADirectory(leConfig.Directory) s.resolveCADirectory(leConfig.Directory)
err = s.resolveCACertificate(leConfig.Certificate) err = s.resolveCACertificate(leConfig.Certificate)
@ -286,50 +275,30 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
d := &localCertData{} d := &localCertData{}
// sql or something, the query is in `find-next-cert.sql` // sql or something, the query is in `find-next-cert.sql`
row, err := s.db.Query(findNextCertSql) row, err := s.db.FindNextCert(context.Background())
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to run query: %w", err) return nil, fmt.Errorf("failed to run query: %w", err)
} }
defer row.Close()
// if next fails no rows were found d.id = row.ID
if !row.Next() { d.dns.name = row.Type
return nil, nil d.dns.token = row.Token
} d.notAfter = row.NotAfter
d.tempParent = row.TempParent
// scan the first row
err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token, &d.tempParent)
switch err {
case nil:
// no nothing
break
case io.EOF:
// no certificate to update
return nil, nil
default:
return nil, fmt.Errorf("failed to scan table row: %w", err)
}
return d, nil return d, nil
} }
func (s *Service) fetchDomains(localData *localCertData) ([]string, error) { func (s *Service) fetchDomains(localData *localCertData) ([]string, error) {
// more sql: this one just grabs all the domains for a certificate // more sql: this one just grabs all the domains for a certificate
query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, resolveTempParent(localData)) domains, err := s.db.GetDomainsForCertificate(context.Background(), resolveTempParent(localData))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err) return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err)
} }
// convert query responses to a string slice
domains := make([]string, 0)
for query.Next() {
var domain string
err := query.Scan(&domain)
if err != nil {
return nil, fmt.Errorf("failed to scan row from domains table: %d: %w", localData.id, err)
}
domains = append(domains, domain)
}
// if no domains were found then the renewal will fail // if no domains were found then the renewal will fail
if len(domains) == 0 { if len(domains) == 0 {
return nil, fmt.Errorf("no domains registered for certificate: %d", localData.id) return nil, fmt.Errorf("no domains registered for certificate: %d", localData.id)
@ -391,7 +360,7 @@ func (s *Service) getDnsProvider(name, token string) (challenge.Provider, error)
// getPrivateKey reads the private key for the specified certificate id, or // getPrivateKey reads the private key for the specified certificate id, or
// generates one is the file doesn't exist // generates one is the file doesn't exist
func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) { func (s *Service) getPrivateKey(id int64) (*rsa.PrivateKey, error) {
fPath := filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id)) fPath := filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id))
pemBytes, err := os.ReadFile(fPath) pemBytes, err := os.ReadFile(fPath)
if err != nil { if err != nil {
@ -433,13 +402,20 @@ func (s *Service) renewCert(localData *localCertData) error {
} }
// set the NotAfter/NotBefore in the database // set the NotAfter/NotBefore in the database
_, err = s.db.Exec(`UPDATE certificates SET renewing = 0, renew_failed = 0, not_after = ?, updated_at = ? WHERE id = ?`, cert.NotAfter, cert.NotBefore, localData.id) err = s.db.UpdateCertAfterRenewal(context.Background(), database.UpdateCertAfterRenewalParams{
NotAfter: cert.NotAfter,
UpdatedAt: cert.NotBefore,
ID: localData.id,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err) return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err)
} }
// set domains to normal state // set domains to normal state
_, err = s.db.Exec(`UPDATE certificate_domains SET state = ? WHERE cert_id = ?`, DomainStateNormal, localData.id) err = s.db.SetDomainStateForCert(context.Background(), database.SetDomainStateForCertParams{
State: DomainStateNormal,
CertID: localData.id,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to update domains for %d in database: %w", localData.id, err) return fmt.Errorf("failed to update domains for %d in database: %w", localData.id, err)
} }
@ -517,8 +493,12 @@ func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate
// setRenewing sets the renewing and failed states in the database for a // setRenewing sets the renewing and failed states in the database for a
// specified certificate id. // specified certificate id.
func (s *Service) setRenewing(id uint64, renewing, failed bool) { func (s *Service) setRenewing(id int64, renewing, failed bool) {
_, err := s.db.Exec("UPDATE certificates SET renewing = ?, renew_failed = ? WHERE id = ?", renewing, failed, id) err := s.db.UpdateRenewingState(context.Background(), database.UpdateRenewingStateParams{
Renewing: renewing,
RenewFailed: failed,
ID: id,
})
if err != nil { if err != nil {
log.Printf("[Renewal] Failed to set renewing/failed mode in database %d: %s\n", id, err) log.Printf("[Renewal] Failed to set renewing/failed mode in database %d: %s\n", id, err)
} }
@ -526,7 +506,7 @@ func (s *Service) setRenewing(id uint64, renewing, failed bool) {
// writeCertFile writes the output certificate file and renames the current one // writeCertFile writes the output certificate file and renames the current one
// to include `-old` in the name. // to include `-old` in the name.
func (s *Service) writeCertFile(id uint64, certBytes []byte) error { func (s *Service) writeCertFile(id int64, certBytes []byte) error {
oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", id)) oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", id))
newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", id)) newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", id))
@ -552,9 +532,9 @@ func (s *Service) writeCertFile(id uint64, certBytes []byte) error {
return nil return nil
} }
func resolveTempParent(local *localCertData) uint64 { func resolveTempParent(local *localCertData) int64 {
if local.tempParent > 0 { if local.tempParent.Valid {
return local.tempParent return local.tempParent.Int64
} }
return local.id return local.id
} }

View File

@ -1,12 +1,14 @@
package servers package servers
import ( import (
"context"
"database/sql" "database/sql"
_ "embed" _ "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims" "github.com/1f349/mjwt/claims"
"github.com/1f349/orchid/database"
oUtils "github.com/1f349/orchid/utils" oUtils "github.com/1f349/orchid/utils"
vUtils "github.com/1f349/violet/utils" vUtils "github.com/1f349/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
@ -22,7 +24,7 @@ type DomainStateValue struct {
} }
type Certificate struct { type Certificate struct {
Id int `json:"id"` Id int64 `json:"id"`
AutoRenew bool `json:"auto_renew"` AutoRenew bool `json:"auto_renew"`
Active bool `json:"active"` Active bool `json:"active"`
Renewing bool `json:"renewing"` Renewing bool `json:"renewing"`
@ -32,14 +34,11 @@ type Certificate struct {
Domains []string `json:"domains"` Domains []string `json:"domains"`
} }
//go:embed find-owned-certs.sql
var findOwnedCerts string
// NewApiServer creates and runs a http server containing all the API // NewApiServer creates and runs a http server containing all the API
// endpoints for the software // endpoints for the software
// //
// `/cert` - edit certificate // `/cert` - edit certificate
func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server { func NewApiServer(listen string, db *database.Queries, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server {
r := httprouter.New() r := httprouter.New()
r.GET("/", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { r.GET("/", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
@ -55,25 +54,28 @@ func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtil
} }
// query database // query database
query, err := db.Query(findOwnedCerts) rows, err := db.FindOwnedCerts(context.Background())
if err != nil { if err != nil {
log.Println("Failed after reading certificates from database: ", err)
http.Error(rw, "Database Error", http.StatusInternalServerError) http.Error(rw, "Database Error", http.StatusInternalServerError)
return return
} }
mOther := make(map[int]*Certificate) // other certificates mOther := make(map[int64]*Certificate) // other certificates
m := make(map[int]*Certificate) // certificates owned by this user m := make(map[int64]*Certificate) // certificates owned by this user
// loop over query rows // loop over query rows
for query.Next() { for _, row := range rows {
var c Certificate c := Certificate{
var d string Id: row.ID,
err := query.Scan(&c.Id, &c.AutoRenew, &c.Active, &c.Renewing, &c.RenewFailed, &c.NotAfter, &c.UpdatedAt, &d) AutoRenew: row.AutoRenew,
if err != nil { Active: row.Active,
log.Println("Failed to read certificate from database: ", err) Renewing: row.Renewing,
http.Error(rw, "Database Error", http.StatusInternalServerError) RenewFailed: row.RenewFailed,
return NotAfter: row.NotAfter,
UpdatedAt: row.UpdatedAt,
} }
d := row.Domain
// check in owned map // check in owned map
if cert, ok := m[c.Id]; ok { if cert, ok := m[c.Id]; ok {
@ -105,11 +107,6 @@ func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtil
m[c.Id] = &c m[c.Id] = &c
} }
} }
if err := query.Err(); err != nil {
log.Println("Failed after reading certificates from database: ", err)
http.Error(rw, "Database Error", http.StatusInternalServerError)
return
}
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_ = json.NewEncoder(rw).Encode(m) _ = json.NewEncoder(rw).Encode(m)
})) }))
@ -124,15 +121,20 @@ func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtil
})) }))
r.POST("/cert", checkAuthWithPerm(signer, "orchid:cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { r.POST("/cert", checkAuthWithPerm(signer, "orchid:cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
_, err := db.Exec(`INSERT INTO certificates (owner, dns, updated_at) VALUES (?, ?, ?)`, b.Subject, 0, time.Now()) err := db.AddCertificate(req.Context(), database.AddCertificateParams{
Owner: b.Subject,
Dns: sql.NullInt64{},
NotAfter: time.Now(),
UpdatedAt: time.Now(),
})
if err != nil { if err != nil {
apiError(rw, http.StatusInternalServerError, "Failed to delete certificate") apiError(rw, http.StatusInternalServerError, "Failed to delete certificate")
return return
} }
rw.WriteHeader(http.StatusAccepted) rw.WriteHeader(http.StatusAccepted)
})) }))
r.DELETE("/cert/:id", checkAuthForCertificate(signer, "orchid:cert", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { r.DELETE("/cert/:id", checkAuthForCertificate(signer, "orchid:cert", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64) {
_, err := db.Exec(`UPDATE certificates SET active = 0 WHERE id = ?`, certId) err := db.RemoveCertificate(req.Context(), certId)
if err != nil { if err != nil {
apiError(rw, http.StatusInternalServerError, "Failed to delete certificate") apiError(rw, http.StatusInternalServerError, "Failed to delete certificate")
return return
@ -194,7 +196,7 @@ func apiError(rw http.ResponseWriter, code int, m string) {
// lookupCertOwner finds the certificate matching the id string and returns the // lookupCertOwner finds the certificate matching the id string and returns the
// numeric id, owner and possible error, only works for active certificates. // numeric id, owner and possible error, only works for active certificates.
func checkCertOwner(db *sql.DB, idStr string, b AuthClaims) (uint64, error) { func checkCertOwner(db *database.Queries, idStr string, b AuthClaims) (int64, error) {
// parse the id // parse the id
rawId, err := strconv.ParseUint(idStr, 10, 64) rawId, err := strconv.ParseUint(idStr, 10, 64)
if err != nil { if err != nil {
@ -202,54 +204,18 @@ func checkCertOwner(db *sql.DB, idStr string, b AuthClaims) (uint64, error) {
} }
// run database query // run database query
row := db.QueryRow(`SELECT id, owner FROM certificates WHERE active = 1 and id = ?`, rawId) row, err := db.CheckCertOwner(context.Background(), int64(rawId))
// scan in result values
var id uint64
var owner string
err = row.Scan(&id, &owner)
if err != nil { if err != nil {
return 0, fmt.Errorf("scan error: %w", err) return 0, err
} }
// check the owner is the mjwt token subject // check the owner is the mjwt token subject
if b.Subject != owner { if b.Subject != row.Owner {
return id, fmt.Errorf("not the certificate owner") return row.ID, fmt.Errorf("not the certificate owner")
} }
// it's all valid, return the values // it's all valid, return the values
return id, nil return row.ID, nil
}
// safeTransaction completes a database transaction safely allowing for rollbacks
// if the callback errors
func safeTransaction(rw http.ResponseWriter, db *sql.DB, cb func(rw http.ResponseWriter, tx *sql.Tx) error) error {
// start a transaction
begin, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin a transaction")
}
// init defer rollback
needsRollback := true
defer func() {
if needsRollback {
_ = begin.Rollback()
}
}()
// run main code within the transaction session
err = cb(rw, begin)
if err != nil {
return err
}
// clear the rollback flag and commit the transaction
needsRollback = false
if begin.Commit() != nil {
return fmt.Errorf("failed to commit a transaction")
}
return nil
} }
// getDomainOwnershipClaims returns the domains marked as owned from PermStorage, // getDomainOwnershipClaims returns the domains marked as owned from PermStorage,

View File

@ -1,9 +1,9 @@
package servers package servers
import ( import (
"database/sql"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth" "github.com/1f349/mjwt/auth"
"github.com/1f349/orchid/database"
vUtils "github.com/1f349/violet/utils" vUtils "github.com/1f349/violet/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"log" "log"
@ -14,7 +14,7 @@ type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims]
type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims)
type CertAuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) type CertAuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64)
// checkAuth validates the bearer token against a mjwt.Verifier and returns an // checkAuth validates the bearer token against a mjwt.Verifier and returns an
// error message or continues to the next handler // error message or continues to the next handler
@ -53,7 +53,7 @@ func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httpr
} }
// checkAuthForCertificate // checkAuthForCertificate
func checkAuthForCertificate(verify mjwt.Verifier, perm string, db *sql.DB, cb CertAuthCallback) httprouter.Handle { func checkAuthForCertificate(verify mjwt.Verifier, perm string, db *database.Queries, cb CertAuthCallback) httprouter.Handle {
return checkAuthWithPerm(verify, perm, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { return checkAuthWithPerm(verify, perm, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
// lookup certificate owner // lookup certificate owner
id, err := checkCertOwner(db, params.ByName("id"), b) id, err := checkCertOwner(db, params.ByName("id"), b)

View File

@ -1,48 +1,37 @@
package servers package servers
import ( import (
"database/sql" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/1f349/mjwt" "github.com/1f349/mjwt"
"github.com/1f349/orchid/database"
"github.com/1f349/orchid/renewal" "github.com/1f349/orchid/renewal"
"github.com/1f349/orchid/utils" "github.com/1f349/orchid/utils"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
) )
func certDomainManageGET(db *sql.DB, signer mjwt.Verifier) httprouter.Handle { func certDomainManageGET(db *database.Queries, signer mjwt.Verifier) httprouter.Handle {
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
query, err := db.Query(`SELECT domain, state FROM certificate_domains WHERE cert_id = ?`, certId) rows, err := db.GetDomainStatesForCert(context.Background(), int64(certId))
if err != nil { if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error") apiError(rw, http.StatusInsufficientStorage, "Database error")
return return
} }
// collect all the domains and state values
var domainStates []DomainStateValue
for query.Next() {
var a DomainStateValue
err := query.Scan(&a.Domain, &a.State)
if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
}
domainStates = append(domainStates, a)
}
// write output // write output
rw.WriteHeader(http.StatusAccepted) rw.WriteHeader(http.StatusAccepted)
m := map[string]any{ m := map[string]any{
"id": fmt.Sprintf("%d", certId), "id": fmt.Sprintf("%d", certId),
"domains": domainStates, "domains": rows,
} }
_ = json.NewEncoder(rw).Encode(m) _ = json.NewEncoder(rw).Encode(m)
}) })
} }
func certDomainManagePUTandDELETE(db *sql.DB, signer mjwt.Verifier, domains utils.DomainChecker) httprouter.Handle { func certDomainManagePUTandDELETE(db *database.Queries, signer mjwt.Verifier, domains utils.DomainChecker) httprouter.Handle {
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId int64) {
// check request type // check request type
isAdd := req.Method == http.MethodPut isAdd := req.Method == http.MethodPut
@ -66,18 +55,25 @@ func certDomainManagePUTandDELETE(db *sql.DB, signer mjwt.Verifier, domains util
} }
// run a safe transaction to insert or update the certificate domains // run a safe transaction to insert or update the certificate domains
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error { if db.UseTx(req.Context(), func(tx *database.Queries) error {
if isAdd { if isAdd {
// insert domains to add // insert domains to add
for _, i := range d { for _, i := range d {
_, err := tx.Exec(`INSERT INTO certificate_domains (cert_id, domain, state) VALUES (?, ?, ?)`, certId, i, renewal.DomainStateAdded) err := tx.AddDomains(req.Context(), database.AddDomainsParams{
CertID: certId,
Domain: i,
State: renewal.DomainStateAdded,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to add domains to the database") return fmt.Errorf("failed to add domains to the database")
} }
} }
} else { } else {
// update domains to removed state // update domains to removed state
_, err := tx.Exec(`UPDATE certificate_domains SET state = ? WHERE domain IN ?`, renewal.DomainStateRemoved, d) err := tx.UpdateDomains(req.Context(), database.UpdateDomainsParams{
State: renewal.DomainStateRemoved,
Domain: d,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to remove domains from the database") return fmt.Errorf("failed to remove domains from the database")
} }

View File

@ -1,3 +0,0 @@
select cert.id, cert.auto_renew, cert.active, cert.renewing, cert.renew_failed, cert.not_after, cert.updated_at, certificate_domains.domain
from certificates as cert
inner join certificate_domains on cert.id = certificate_domains.cert_id

10
sqlc.yaml Normal file
View File

@ -0,0 +1,10 @@
version: "2"
sql:
- engine: sqlite
queries: database/queries
schema: database/migrations
gen:
go:
package: "database"
out: "database"
emit_json_tags: true