mirror of
https://github.com/1f349/orchid.git
synced 2024-12-22 08:04:10 +00:00
Loads of API code
This commit is contained in:
parent
d648555af1
commit
d6927cd822
@ -87,7 +87,7 @@ func normalLoad(conf startUpConfig, wd string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("[Orchid] Error:", err)
|
log.Fatal("[Orchid] Error:", err)
|
||||||
}
|
}
|
||||||
srv := servers.NewApiServer(conf.Listen, mJwtVerify, conf.Domains)
|
srv := servers.NewApiServer(conf.Listen, db, mJwtVerify, conf.Domains)
|
||||||
utils.RunBackgroundHttp("API", srv)
|
utils.RunBackgroundHttp("API", srv)
|
||||||
|
|
||||||
// Wait for exit signal
|
// Wait for exit signal
|
||||||
|
@ -5,8 +5,8 @@ package pebble
|
|||||||
import _ "embed"
|
import _ "embed"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
//go:embed asset/pebble-cert.pem
|
//go:embed pebble-cert.pem
|
||||||
RawCert []byte
|
RawCert []byte
|
||||||
//go:embed asset/pebble-config.json
|
//go:embed pebble-config.json
|
||||||
RawConfig []byte
|
RawConfig []byte
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
CREATE TABLE IF NOT EXISTS certificates
|
CREATE TABLE IF NOT EXISTS certificates
|
||||||
(
|
(
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
owner INTEGER,
|
owner VARCHAR,
|
||||||
dns INTEGER,
|
dns INTEGER,
|
||||||
auto_renew INTEGER DEFAULT 0,
|
auto_renew INTEGER DEFAULT 0,
|
||||||
active INTEGER DEFAULT 0,
|
active INTEGER DEFAULT 0,
|
||||||
@ -9,7 +9,9 @@ CREATE TABLE IF NOT EXISTS certificates
|
|||||||
renew_failed INTEGER DEFAULT 0,
|
renew_failed INTEGER DEFAULT 0,
|
||||||
not_after DATETIME,
|
not_after DATETIME,
|
||||||
updated_at DATETIME,
|
updated_at DATETIME,
|
||||||
FOREIGN KEY (dns) REFERENCES dns (id)
|
temp_parent INTEGER DEFAULT 0,
|
||||||
|
FOREIGN KEY (dns) REFERENCES dns_acme (id),
|
||||||
|
FOREIGN KEY (temp_parent) REFERENCES certificates (id)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS certificate_domains
|
CREATE TABLE IF NOT EXISTS certificate_domains
|
||||||
@ -17,10 +19,12 @@ CREATE TABLE IF NOT EXISTS certificate_domains
|
|||||||
domain_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
domain_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
cert_id INTEGER,
|
cert_id INTEGER,
|
||||||
domain VARCHAR,
|
domain VARCHAR,
|
||||||
|
state INTEGER DEFAULT 1,
|
||||||
|
UNIQUE (cert_id, domain),
|
||||||
FOREIGN KEY (cert_id) REFERENCES certificates (id)
|
FOREIGN KEY (cert_id) REFERENCES certificates (id)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS dns
|
CREATE TABLE IF NOT EXISTS dns_acme
|
||||||
(
|
(
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
type VARCHAR,
|
type VARCHAR,
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
select cert.id, cert.not_after, dns.type, dns.token
|
select cert.id, cert.not_after, dns.type, dns.token, cert.temp_parent
|
||||||
from certificates as cert
|
from certificates as cert
|
||||||
left outer join dns on cert.dns = dns.id
|
left outer join dns on cert.dns = dns.id
|
||||||
where cert.active = 1
|
where cert.active = 1
|
||||||
and cert.auto_renew = 1
|
and (cert.auto_renew = 1 or cert.not_after IS NULL)
|
||||||
and cert.renewing = 0
|
and cert.renewing = 0
|
||||||
and cert.renew_failed = 0
|
and cert.renew_failed = 0
|
||||||
and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
|
and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
|
||||||
order by cert.not_after DESC NULLS FIRST
|
order by cert.temp_parent, cert.not_after DESC NULLS FIRST
|
||||||
|
@ -11,6 +11,7 @@ type localCertData struct {
|
|||||||
name sql.NullString
|
name sql.NullString
|
||||||
token sql.NullString
|
token sql.NullString
|
||||||
}
|
}
|
||||||
notAfter sql.NullTime
|
notAfter sql.NullTime
|
||||||
domains []string
|
domains []string
|
||||||
|
tempParent uint64
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,12 @@ var (
|
|||||||
createTableCertificates string
|
createTableCertificates string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DomainStateNormal = 0
|
||||||
|
DomainStateAdded = 1
|
||||||
|
DomainStateRemoved = 2
|
||||||
|
)
|
||||||
|
|
||||||
// overrides only used in testing
|
// overrides only used in testing
|
||||||
var testDnsOptions interface {
|
var testDnsOptions interface {
|
||||||
challenge.Provider
|
challenge.Provider
|
||||||
@ -282,7 +288,7 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// scan the first row
|
// scan the first row
|
||||||
err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token)
|
err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token, &d.tempParent)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
// no nothing
|
// no nothing
|
||||||
@ -299,7 +305,7 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
|
|||||||
|
|
||||||
func (s *Service) fetchDomains(localData *localCertData) ([]string, error) {
|
func (s *Service) fetchDomains(localData *localCertData) ([]string, error) {
|
||||||
// more sql: this one just grabs all the domains for a certificate
|
// more sql: this one just grabs all the domains for a certificate
|
||||||
query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, localData.id)
|
query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, resolveTempParent(localData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err)
|
return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err)
|
||||||
}
|
}
|
||||||
@ -343,9 +349,9 @@ func (s *Service) setupLegoClient(localData *localCertData) (*lego.Client, error
|
|||||||
if testDnsOptions != nil {
|
if testDnsOptions != nil {
|
||||||
// set up the dns provider used during tests and disable propagation as no dns
|
// set up the dns provider used during tests and disable propagation as no dns
|
||||||
// will validate these tests
|
// will validate these tests
|
||||||
dnsAddrs := testDnsOptions.GetDnsAddrs()
|
dnsAddr := testDnsOptions.GetDnsAddrs()
|
||||||
log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddrs)
|
log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddr)
|
||||||
_ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddrs), dns01.DisableCompletePropagationRequirement())
|
_ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddr), dns01.DisableCompletePropagationRequirement())
|
||||||
} else if localData.dns.name.Valid && localData.dns.token.Valid {
|
} else if localData.dns.name.Valid && localData.dns.token.Valid {
|
||||||
// if the dns name and token are "valid" meaning non-null in this case
|
// if the dns name and token are "valid" meaning non-null in this case
|
||||||
// set up the specific dns provider requested
|
// set up the specific dns provider requested
|
||||||
@ -413,6 +419,12 @@ func (s *Service) renewCert(localData *localCertData) error {
|
|||||||
return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err)
|
return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set domains to normal state
|
||||||
|
_, err = s.db.Exec(`UPDATE certificate_domains SET state = ? WHERE cert_id = ?`, DomainStateNormal, localData.id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update domains for %d in database: %w", localData.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
// write out the certificate file
|
// write out the certificate file
|
||||||
err = s.writeCertFile(localData.id, certBytes)
|
err = s.writeCertFile(localData.id, certBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -506,3 +518,10 @@ func (s *Service) writeCertFile(id uint64, certBytes []byte) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveTempParent(local *localCertData) uint64 {
|
||||||
|
if local.tempParent > 0 {
|
||||||
|
return local.tempParent
|
||||||
|
}
|
||||||
|
return local.id
|
||||||
|
}
|
||||||
|
170
servers/api.go
170
servers/api.go
@ -1,40 +1,88 @@
|
|||||||
package servers
|
package servers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"github.com/MrMelon54/mjwt"
|
"github.com/MrMelon54/mjwt"
|
||||||
"github.com/MrMelon54/mjwt/auth"
|
|
||||||
oUtils "github.com/MrMelon54/orchid/utils"
|
oUtils "github.com/MrMelon54/orchid/utils"
|
||||||
vUtils "github.com/MrMelon54/violet/utils"
|
vUtils "github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type DomainStateValue struct {
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
State int `json:"state"`
|
||||||
|
}
|
||||||
|
|
||||||
// NewApiServer creates and runs a http server containing all the API
|
// NewApiServer creates and runs a http server containing all the API
|
||||||
// endpoints for the software
|
// endpoints for the software
|
||||||
//
|
//
|
||||||
// `/cert` - edit certificate
|
// `/cert` - edit certificate
|
||||||
func NewApiServer(listen string, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server {
|
func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
|
|
||||||
// Endpoint for adding a certificate
|
// Endpoint for looking up a certificate
|
||||||
r.POST("/cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
r.GET("/lookup/:domain", checkAuthWithPerm(signer, "orchid:cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
// TODO: register domains to a certificate
|
|
||||||
vUtils.RespondVioletError(rw, http.StatusNotImplemented, "API unavailable")
|
|
||||||
rw.WriteHeader(http.StatusNotImplemented)
|
|
||||||
return
|
|
||||||
|
|
||||||
if !hasPerms(signer, req, "orchid:cert:") {
|
|
||||||
vUtils.RespondHttpStatus(rw, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
domain := params.ByName("domain")
|
domain := params.ByName("domain")
|
||||||
if !domains.ValidateDomain(domain) {
|
if !domains.ValidateDomain(domain) {
|
||||||
vUtils.RespondVioletError(rw, http.StatusBadRequest, "Invalid domain")
|
vUtils.RespondVioletError(rw, http.StatusBadRequest, "Invalid domain")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
r.POST("/cert", checkAuthWithPerm(signer, "orchid:cert:create", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
_, err := db.Exec(`INSERT INTO certificates (owner, dns, updated_at) VALUES (?, ?, ?)`, b.Subject, 0, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusInternalServerError, "Failed to delete certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
})
|
}))
|
||||||
|
r.DELETE("/cert/:id", checkAuthForCertificate(signer, "orchid:cert:delete", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
|
||||||
|
_, err := db.Exec(`UPDATE certificates SET active = 0 WHERE id = ?`, certId)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusInternalServerError, "Failed to delete certificate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Endpoint for adding/removing domains to/from a certificate
|
||||||
|
manageGet, managePutDelete := certDomainManageGET(db, signer), certDomainManagePUTandDELETE(db, signer, domains)
|
||||||
|
r.GET("/cert/:id/domains", manageGet)
|
||||||
|
r.PUT("/cert/:id/domains", managePutDelete)
|
||||||
|
r.DELETE("/cert/:id/domains", managePutDelete)
|
||||||
|
|
||||||
|
// Endpoint for generating a temporary certificate for modified domains
|
||||||
|
r.POST("/cert/:id/temp", checkAuth(signer, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
if !b.Claims.Perms.Has("orchid:cert:quick") {
|
||||||
|
apiError(rw, http.StatusForbidden, "No permission")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookup certificate owner
|
||||||
|
id, err := checkCertOwner(db, "", b)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// run a safe transaction to create the temporary certificate
|
||||||
|
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error {
|
||||||
|
// insert temporary certificate into database
|
||||||
|
_, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id)
|
||||||
|
return err
|
||||||
|
}) != nil {
|
||||||
|
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||||
|
fmt.Printf("Internal error: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
// Create and run http server
|
// Create and run http server
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
@ -48,19 +96,83 @@ func NewApiServer(listen string, signer mjwt.Verifier, domains oUtils.DomainChec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool {
|
// apiError outputs a generic JSON error message
|
||||||
// Get bearer token
|
func apiError(rw http.ResponseWriter, code int, m string) {
|
||||||
bearer := vUtils.GetBearer(req)
|
rw.WriteHeader(code)
|
||||||
if bearer == "" {
|
_ = json.NewEncoder(rw).Encode(map[string]string{
|
||||||
return false
|
"error": m,
|
||||||
}
|
})
|
||||||
|
}
|
||||||
// Read claims from mjwt
|
|
||||||
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
// lookupCertOwner finds the certificate matching the id string and returns the
|
||||||
if err != nil {
|
// numeric id, owner and possible error, only works for active certificates.
|
||||||
return false
|
func checkCertOwner(db *sql.DB, idStr string, b AuthClaims) (uint64, error) {
|
||||||
}
|
// parse the id
|
||||||
|
rawId, err := strconv.ParseUint(idStr, 10, 64)
|
||||||
// Token must have perm
|
if err != nil {
|
||||||
return b.Claims.Perms.Has(perm)
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// run database query
|
||||||
|
row := db.QueryRow(`SELECT id, owner FROM certificates WHERE active = 1 and id = ?`, rawId)
|
||||||
|
|
||||||
|
// scan in result values
|
||||||
|
var id uint64
|
||||||
|
var owner string
|
||||||
|
err = row.Scan(&id, &owner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("scan error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check the owner is the mjwt token subject
|
||||||
|
if b.Subject != owner {
|
||||||
|
return id, fmt.Errorf("not the certificate owner")
|
||||||
|
}
|
||||||
|
|
||||||
|
// it's all valid, return the values
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeTransaction completes a database transaction safely allowing for rollbacks
|
||||||
|
// if the callback errors
|
||||||
|
func safeTransaction(rw http.ResponseWriter, db *sql.DB, cb func(rw http.ResponseWriter, tx *sql.Tx) error) error {
|
||||||
|
// start a transaction
|
||||||
|
begin, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to begin a transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// init defer rollback
|
||||||
|
needsRollback := true
|
||||||
|
defer func() {
|
||||||
|
if needsRollback {
|
||||||
|
_ = begin.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// run main code within the transaction session
|
||||||
|
err = cb(rw, begin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear the rollback flag and commit the transaction
|
||||||
|
needsRollback = false
|
||||||
|
if begin.Commit() != nil {
|
||||||
|
return fmt.Errorf("failed to commit a transaction")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDomainAudienceClaims validates if the audience claims contain the
|
||||||
|
// `owns=<fqdn>` field with the matching top level domain
|
||||||
|
func validateDomainAudienceClaims(a string, aud jwt.ClaimStrings) bool {
|
||||||
|
if fqdn, ok := vUtils.GetTopFqdn(a); ok {
|
||||||
|
for _, i := range aud {
|
||||||
|
if i == "owns="+fqdn {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
66
servers/auth.go
Normal file
66
servers/auth.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package servers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/mjwt/auth"
|
||||||
|
vUtils "github.com/MrMelon54/violet/utils"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims]
|
||||||
|
|
||||||
|
type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims)
|
||||||
|
|
||||||
|
type CertAuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64)
|
||||||
|
|
||||||
|
// checkAuth validates the bearer token against a mjwt.Verifier and returns an
|
||||||
|
// error message or continues to the next handler
|
||||||
|
func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle {
|
||||||
|
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||||
|
// Get bearer token
|
||||||
|
bearer := vUtils.GetBearer(req)
|
||||||
|
if bearer == "" {
|
||||||
|
apiError(rw, http.StatusForbidden, "Missing bearer token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read claims from mjwt
|
||||||
|
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusForbidden, "Invalid token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(rw, req, params, AuthClaims(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAuthWithPerm validates the bearer token and checks if it contains a
|
||||||
|
// required permission and returns an error message or continues to the next
|
||||||
|
// handler
|
||||||
|
func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle {
|
||||||
|
return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
// check perms
|
||||||
|
if !b.Claims.Perms.Has(perm) {
|
||||||
|
apiError(rw, http.StatusForbidden, "No permission")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cb(rw, req, params, b)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAuthForCertificate
|
||||||
|
func checkAuthForCertificate(verify mjwt.Verifier, perm string, db *sql.DB, cb CertAuthCallback) httprouter.Handle {
|
||||||
|
return checkAuthWithPerm(verify, perm, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
|
||||||
|
// lookup certificate owner
|
||||||
|
id, err := checkCertOwner(db, "", b)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(rw, req, params, b, id)
|
||||||
|
})
|
||||||
|
}
|
108
servers/certDomainManage.go
Normal file
108
servers/certDomainManage.go
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
package servers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/MrMelon54/mjwt"
|
||||||
|
"github.com/MrMelon54/orchid/renewal"
|
||||||
|
"github.com/MrMelon54/orchid/utils"
|
||||||
|
"github.com/julienschmidt/httprouter"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func certDomainManageGET(db *sql.DB, signer mjwt.Verifier) httprouter.Handle {
|
||||||
|
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
|
||||||
|
query, err := db.Query(`SELECT domain, state FROM certificate_domains WHERE cert_id = ?`, certId)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// collect all the domains and state values
|
||||||
|
var domainStates []DomainStateValue
|
||||||
|
for query.Next() {
|
||||||
|
var a DomainStateValue
|
||||||
|
err := query.Scan(&a.Domain, &a.State)
|
||||||
|
if err != nil {
|
||||||
|
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
domainStates = append(domainStates, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// write output
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
m := map[string]any{
|
||||||
|
"id": fmt.Sprintf("%d", certId),
|
||||||
|
"domains": domainStates,
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(rw).Encode(m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func certDomainManagePUTandDELETE(db *sql.DB, signer mjwt.Verifier, domains utils.DomainChecker) httprouter.Handle {
|
||||||
|
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
|
||||||
|
// check request type
|
||||||
|
isAdd := req.Method == http.MethodPut
|
||||||
|
|
||||||
|
if len(b.Audience) == 0 {
|
||||||
|
apiError(rw, http.StatusForbidden, "Missing audience tag, to specify owned domains")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// read domains from request body
|
||||||
|
var d []string
|
||||||
|
if json.NewDecoder(req.Body).Decode(&d) != nil {
|
||||||
|
apiError(rw, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate all domains
|
||||||
|
for _, i := range d {
|
||||||
|
if !validateDomainAudienceClaims(i, b.Audience) {
|
||||||
|
apiError(rw, http.StatusBadRequest, "Token cannot modify a specified domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !domains.ValidateDomain(i) {
|
||||||
|
apiError(rw, http.StatusBadRequest, "Invalid domain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// run a safe transaction to insert or update the certificate domains
|
||||||
|
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error {
|
||||||
|
if isAdd {
|
||||||
|
// insert domains to add
|
||||||
|
for _, i := range d {
|
||||||
|
_, err := tx.Exec(`INSERT INTO certificate_domains (cert_id, domain, state) VALUES (?, ?, ?)`, certId, i, renewal.DomainStateAdded)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add domains to the database")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// update domains to removed state
|
||||||
|
_, err := tx.Exec(`UPDATE certificate_domains SET state = ? WHERE domain IN ?`, renewal.DomainStateRemoved, d)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove domains from the database")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}) != nil {
|
||||||
|
apiError(rw, http.StatusInsufficientStorage, "Database error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// write output
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
m := map[string]any{
|
||||||
|
"id": fmt.Sprintf("%d", certId),
|
||||||
|
}
|
||||||
|
if isAdd {
|
||||||
|
m["add_domains"] = d
|
||||||
|
} else {
|
||||||
|
m["remove_domains"] = d
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(rw).Encode(m)
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user