Loads of API code

This commit is contained in:
Melon 2023-07-10 17:51:14 +01:00
parent d648555af1
commit d6927cd822
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
11 changed files with 355 additions and 45 deletions

View File

@ -87,7 +87,7 @@ func normalLoad(conf startUpConfig, wd string) {
if err != nil {
log.Fatal("[Orchid] Error:", err)
}
srv := servers.NewApiServer(conf.Listen, mJwtVerify, conf.Domains)
srv := servers.NewApiServer(conf.Listen, db, mJwtVerify, conf.Domains)
utils.RunBackgroundHttp("API", srv)
// Wait for exit signal

View File

@ -5,8 +5,8 @@ package pebble
import _ "embed"
var (
//go:embed asset/pebble-cert.pem
//go:embed pebble-cert.pem
RawCert []byte
//go:embed asset/pebble-config.json
//go:embed pebble-config.json
RawConfig []byte
)

View File

@ -1,7 +1,7 @@
CREATE TABLE IF NOT EXISTS certificates
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner INTEGER,
owner VARCHAR,
dns INTEGER,
auto_renew INTEGER DEFAULT 0,
active INTEGER DEFAULT 0,
@ -9,7 +9,9 @@ CREATE TABLE IF NOT EXISTS certificates
renew_failed INTEGER DEFAULT 0,
not_after DATETIME,
updated_at DATETIME,
FOREIGN KEY (dns) REFERENCES dns (id)
temp_parent INTEGER DEFAULT 0,
FOREIGN KEY (dns) REFERENCES dns_acme (id),
FOREIGN KEY (temp_parent) REFERENCES certificates (id)
);
CREATE TABLE IF NOT EXISTS certificate_domains
@ -17,10 +19,12 @@ CREATE TABLE IF NOT EXISTS certificate_domains
domain_id INTEGER PRIMARY KEY AUTOINCREMENT,
cert_id INTEGER,
domain VARCHAR,
state INTEGER DEFAULT 1,
UNIQUE (cert_id, domain),
FOREIGN KEY (cert_id) REFERENCES certificates (id)
);
CREATE TABLE IF NOT EXISTS dns
CREATE TABLE IF NOT EXISTS dns_acme
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
type VARCHAR,

View File

@ -1,9 +1,9 @@
select cert.id, cert.not_after, dns.type, dns.token
select cert.id, cert.not_after, dns.type, dns.token, cert.temp_parent
from certificates as cert
left outer join dns on cert.dns = dns.id
where cert.active = 1
and cert.auto_renew = 1
and (cert.auto_renew = 1 or cert.not_after IS NULL)
and cert.renewing = 0
and cert.renew_failed = 0
and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME())
order by cert.not_after DESC NULLS FIRST
order by cert.temp_parent, cert.not_after DESC NULLS FIRST

View File

@ -11,6 +11,7 @@ type localCertData struct {
name sql.NullString
token sql.NullString
}
notAfter sql.NullTime
domains []string
notAfter sql.NullTime
domains []string
tempParent uint64
}

View File

@ -35,6 +35,12 @@ var (
createTableCertificates string
)
const (
DomainStateNormal = 0
DomainStateAdded = 1
DomainStateRemoved = 2
)
// overrides only used in testing
var testDnsOptions interface {
challenge.Provider
@ -282,7 +288,7 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
}
// scan the first row
err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token)
err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token, &d.tempParent)
switch err {
case nil:
// no nothing
@ -299,7 +305,7 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
func (s *Service) fetchDomains(localData *localCertData) ([]string, error) {
// more sql: this one just grabs all the domains for a certificate
query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, localData.id)
query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, resolveTempParent(localData))
if err != nil {
return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err)
}
@ -343,9 +349,9 @@ func (s *Service) setupLegoClient(localData *localCertData) (*lego.Client, error
if testDnsOptions != nil {
// set up the dns provider used during tests and disable propagation as no dns
// will validate these tests
dnsAddrs := testDnsOptions.GetDnsAddrs()
log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddrs)
_ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddrs), dns01.DisableCompletePropagationRequirement())
dnsAddr := testDnsOptions.GetDnsAddrs()
log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddr)
_ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddr), dns01.DisableCompletePropagationRequirement())
} else if localData.dns.name.Valid && localData.dns.token.Valid {
// if the dns name and token are "valid" meaning non-null in this case
// set up the specific dns provider requested
@ -413,6 +419,12 @@ func (s *Service) renewCert(localData *localCertData) error {
return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err)
}
// set domains to normal state
_, err = s.db.Exec(`UPDATE certificate_domains SET state = ? WHERE cert_id = ?`, DomainStateNormal, localData.id)
if err != nil {
return fmt.Errorf("failed to update domains for %d in database: %w", localData.id, err)
}
// write out the certificate file
err = s.writeCertFile(localData.id, certBytes)
if err != nil {
@ -506,3 +518,10 @@ func (s *Service) writeCertFile(id uint64, certBytes []byte) error {
return nil
}
func resolveTempParent(local *localCertData) uint64 {
if local.tempParent > 0 {
return local.tempParent
}
return local.id
}

View File

@ -1,40 +1,88 @@
package servers
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
oUtils "github.com/MrMelon54/orchid/utils"
vUtils "github.com/MrMelon54/violet/utils"
"github.com/golang-jwt/jwt/v4"
"github.com/julienschmidt/httprouter"
"net/http"
"strconv"
"time"
)
type DomainStateValue struct {
Domain string `json:"domain"`
State int `json:"state"`
}
// NewApiServer creates and runs a http server containing all the API
// endpoints for the software
//
// `/cert` - edit certificate
func NewApiServer(listen string, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server {
func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server {
r := httprouter.New()
// Endpoint for adding a certificate
r.POST("/cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
// TODO: register domains to a certificate
vUtils.RespondVioletError(rw, http.StatusNotImplemented, "API unavailable")
rw.WriteHeader(http.StatusNotImplemented)
return
if !hasPerms(signer, req, "orchid:cert:") {
vUtils.RespondHttpStatus(rw, http.StatusForbidden)
return
}
// Endpoint for looking up a certificate
r.GET("/lookup/:domain", checkAuthWithPerm(signer, "orchid:cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
domain := params.ByName("domain")
if !domains.ValidateDomain(domain) {
vUtils.RespondVioletError(rw, http.StatusBadRequest, "Invalid domain")
return
}
}))
r.POST("/cert", checkAuthWithPerm(signer, "orchid:cert:create", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
_, err := db.Exec(`INSERT INTO certificates (owner, dns, updated_at) VALUES (?, ?, ?)`, b.Subject, 0, time.Now())
if err != nil {
apiError(rw, http.StatusInternalServerError, "Failed to delete certificate")
return
}
rw.WriteHeader(http.StatusAccepted)
})
}))
r.DELETE("/cert/:id", checkAuthForCertificate(signer, "orchid:cert:delete", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
_, err := db.Exec(`UPDATE certificates SET active = 0 WHERE id = ?`, certId)
if err != nil {
apiError(rw, http.StatusInternalServerError, "Failed to delete certificate")
return
}
rw.WriteHeader(http.StatusAccepted)
}))
// Endpoint for adding/removing domains to/from a certificate
manageGet, managePutDelete := certDomainManageGET(db, signer), certDomainManagePUTandDELETE(db, signer, domains)
r.GET("/cert/:id/domains", manageGet)
r.PUT("/cert/:id/domains", managePutDelete)
r.DELETE("/cert/:id/domains", managePutDelete)
// Endpoint for generating a temporary certificate for modified domains
r.POST("/cert/:id/temp", checkAuth(signer, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
if !b.Claims.Perms.Has("orchid:cert:quick") {
apiError(rw, http.StatusForbidden, "No permission")
return
}
// lookup certificate owner
id, err := checkCertOwner(db, "", b)
if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
}
// run a safe transaction to create the temporary certificate
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error {
// insert temporary certificate into database
_, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id)
return err
}) != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
fmt.Printf("Internal error: %s\n", err)
return
}
}))
// Create and run http server
return &http.Server{
@ -48,19 +96,83 @@ func NewApiServer(listen string, signer mjwt.Verifier, domains oUtils.DomainChec
}
}
func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool {
// Get bearer token
bearer := vUtils.GetBearer(req)
if bearer == "" {
return false
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
return false
}
// Token must have perm
return b.Claims.Perms.Has(perm)
// apiError outputs a generic JSON error message
func apiError(rw http.ResponseWriter, code int, m string) {
rw.WriteHeader(code)
_ = json.NewEncoder(rw).Encode(map[string]string{
"error": m,
})
}
// lookupCertOwner finds the certificate matching the id string and returns the
// numeric id, owner and possible error, only works for active certificates.
func checkCertOwner(db *sql.DB, idStr string, b AuthClaims) (uint64, error) {
// parse the id
rawId, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
return 0, err
}
// run database query
row := db.QueryRow(`SELECT id, owner FROM certificates WHERE active = 1 and id = ?`, rawId)
// scan in result values
var id uint64
var owner string
err = row.Scan(&id, &owner)
if err != nil {
return 0, fmt.Errorf("scan error: %w", err)
}
// check the owner is the mjwt token subject
if b.Subject != owner {
return id, fmt.Errorf("not the certificate owner")
}
// it's all valid, return the values
return id, nil
}
// safeTransaction completes a database transaction safely allowing for rollbacks
// if the callback errors
func safeTransaction(rw http.ResponseWriter, db *sql.DB, cb func(rw http.ResponseWriter, tx *sql.Tx) error) error {
// start a transaction
begin, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin a transaction")
}
// init defer rollback
needsRollback := true
defer func() {
if needsRollback {
_ = begin.Rollback()
}
}()
// run main code within the transaction session
err = cb(rw, begin)
if err != nil {
return err
}
// clear the rollback flag and commit the transaction
needsRollback = false
if begin.Commit() != nil {
return fmt.Errorf("failed to commit a transaction")
}
return nil
}
// validateDomainAudienceClaims validates if the audience claims contain the
// `owns=<fqdn>` field with the matching top level domain
func validateDomainAudienceClaims(a string, aud jwt.ClaimStrings) bool {
if fqdn, ok := vUtils.GetTopFqdn(a); ok {
for _, i := range aud {
if i == "owns="+fqdn {
return true
}
}
}
return false
}

66
servers/auth.go Normal file
View File

@ -0,0 +1,66 @@
package servers
import (
"database/sql"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/mjwt/auth"
vUtils "github.com/MrMelon54/violet/utils"
"github.com/julienschmidt/httprouter"
"net/http"
)
type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims]
type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims)
type CertAuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64)
// checkAuth validates the bearer token against a mjwt.Verifier and returns an
// error message or continues to the next handler
func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle {
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
// Get bearer token
bearer := vUtils.GetBearer(req)
if bearer == "" {
apiError(rw, http.StatusForbidden, "Missing bearer token")
return
}
// Read claims from mjwt
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer)
if err != nil {
apiError(rw, http.StatusForbidden, "Invalid token")
return
}
cb(rw, req, params, AuthClaims(b))
}
}
// checkAuthWithPerm validates the bearer token and checks if it contains a
// required permission and returns an error message or continues to the next
// handler
func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle {
return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
// check perms
if !b.Claims.Perms.Has(perm) {
apiError(rw, http.StatusForbidden, "No permission")
return
}
cb(rw, req, params, b)
})
}
// checkAuthForCertificate
func checkAuthForCertificate(verify mjwt.Verifier, perm string, db *sql.DB, cb CertAuthCallback) httprouter.Handle {
return checkAuthWithPerm(verify, perm, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
// lookup certificate owner
id, err := checkCertOwner(db, "", b)
if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
}
cb(rw, req, params, b, id)
})
}

108
servers/certDomainManage.go Normal file
View File

@ -0,0 +1,108 @@
package servers
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/MrMelon54/mjwt"
"github.com/MrMelon54/orchid/renewal"
"github.com/MrMelon54/orchid/utils"
"github.com/julienschmidt/httprouter"
"net/http"
)
func certDomainManageGET(db *sql.DB, signer mjwt.Verifier) httprouter.Handle {
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
query, err := db.Query(`SELECT domain, state FROM certificate_domains WHERE cert_id = ?`, certId)
if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
}
// collect all the domains and state values
var domainStates []DomainStateValue
for query.Next() {
var a DomainStateValue
err := query.Scan(&a.Domain, &a.State)
if err != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
}
domainStates = append(domainStates, a)
}
// write output
rw.WriteHeader(http.StatusAccepted)
m := map[string]any{
"id": fmt.Sprintf("%d", certId),
"domains": domainStates,
}
_ = json.NewEncoder(rw).Encode(m)
})
}
func certDomainManagePUTandDELETE(db *sql.DB, signer mjwt.Verifier, domains utils.DomainChecker) httprouter.Handle {
return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) {
// check request type
isAdd := req.Method == http.MethodPut
if len(b.Audience) == 0 {
apiError(rw, http.StatusForbidden, "Missing audience tag, to specify owned domains")
return
}
// read domains from request body
var d []string
if json.NewDecoder(req.Body).Decode(&d) != nil {
apiError(rw, http.StatusBadRequest, "Invalid request body")
return
}
// validate all domains
for _, i := range d {
if !validateDomainAudienceClaims(i, b.Audience) {
apiError(rw, http.StatusBadRequest, "Token cannot modify a specified domain")
return
}
if !domains.ValidateDomain(i) {
apiError(rw, http.StatusBadRequest, "Invalid domain")
return
}
}
// run a safe transaction to insert or update the certificate domains
if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error {
if isAdd {
// insert domains to add
for _, i := range d {
_, err := tx.Exec(`INSERT INTO certificate_domains (cert_id, domain, state) VALUES (?, ?, ?)`, certId, i, renewal.DomainStateAdded)
if err != nil {
return fmt.Errorf("failed to add domains to the database")
}
}
} else {
// update domains to removed state
_, err := tx.Exec(`UPDATE certificate_domains SET state = ? WHERE domain IN ?`, renewal.DomainStateRemoved, d)
if err != nil {
return fmt.Errorf("failed to remove domains from the database")
}
}
return nil
}) != nil {
apiError(rw, http.StatusInsufficientStorage, "Database error")
return
}
// write output
rw.WriteHeader(http.StatusAccepted)
m := map[string]any{
"id": fmt.Sprintf("%d", certId),
}
if isAdd {
m["add_domains"] = d
} else {
m["remove_domains"] = d
}
_ = json.NewEncoder(rw).Encode(m)
})
}