diff --git a/cmd/orchid/serve.go b/cmd/orchid/serve.go index a928556..f2a3d59 100644 --- a/cmd/orchid/serve.go +++ b/cmd/orchid/serve.go @@ -87,7 +87,7 @@ func normalLoad(conf startUpConfig, wd string) { if err != nil { log.Fatal("[Orchid] Error:", err) } - srv := servers.NewApiServer(conf.Listen, mJwtVerify, conf.Domains) + srv := servers.NewApiServer(conf.Listen, db, mJwtVerify, conf.Domains) utils.RunBackgroundHttp("API", srv) // Wait for exit signal diff --git a/pebble/asset/pebble-cert.pem b/pebble/pebble-cert.pem similarity index 100% rename from pebble/asset/pebble-cert.pem rename to pebble/pebble-cert.pem diff --git a/pebble/asset/pebble-config.json b/pebble/pebble-config.json similarity index 100% rename from pebble/asset/pebble-config.json rename to pebble/pebble-config.json diff --git a/pebble/pebble.go b/pebble/pebble.go index bfaf4f0..ce04439 100644 --- a/pebble/pebble.go +++ b/pebble/pebble.go @@ -5,8 +5,8 @@ package pebble import _ "embed" var ( - //go:embed asset/pebble-cert.pem + //go:embed pebble-cert.pem RawCert []byte - //go:embed asset/pebble-config.json + //go:embed pebble-config.json RawConfig []byte ) diff --git a/renewal/create-tables.sql b/renewal/create-tables.sql index 7a1016d..c594638 100644 --- a/renewal/create-tables.sql +++ b/renewal/create-tables.sql @@ -1,7 +1,7 @@ CREATE TABLE IF NOT EXISTS certificates ( id INTEGER PRIMARY KEY AUTOINCREMENT, - owner INTEGER, + owner VARCHAR, dns INTEGER, auto_renew INTEGER DEFAULT 0, active INTEGER DEFAULT 0, @@ -9,7 +9,9 @@ CREATE TABLE IF NOT EXISTS certificates renew_failed INTEGER DEFAULT 0, not_after DATETIME, updated_at DATETIME, - FOREIGN KEY (dns) REFERENCES dns (id) + temp_parent INTEGER DEFAULT 0, + FOREIGN KEY (dns) REFERENCES dns_acme (id), + FOREIGN KEY (temp_parent) REFERENCES certificates (id) ); CREATE TABLE IF NOT EXISTS certificate_domains @@ -17,10 +19,12 @@ CREATE TABLE IF NOT EXISTS certificate_domains domain_id INTEGER PRIMARY KEY AUTOINCREMENT, cert_id INTEGER, domain VARCHAR, + state INTEGER DEFAULT 1, + UNIQUE (cert_id, domain), FOREIGN KEY (cert_id) REFERENCES certificates (id) ); -CREATE TABLE IF NOT EXISTS dns +CREATE TABLE IF NOT EXISTS dns_acme ( id INTEGER PRIMARY KEY AUTOINCREMENT, type VARCHAR, diff --git a/renewal/find-next-cert.sql b/renewal/find-next-cert.sql index 315227e..ec8f99d 100644 --- a/renewal/find-next-cert.sql +++ b/renewal/find-next-cert.sql @@ -1,9 +1,9 @@ -select cert.id, cert.not_after, dns.type, dns.token +select cert.id, cert.not_after, dns.type, dns.token, cert.temp_parent from certificates as cert left outer join dns on cert.dns = dns.id where cert.active = 1 - and cert.auto_renew = 1 + and (cert.auto_renew = 1 or cert.not_after IS NULL) and cert.renewing = 0 and cert.renew_failed = 0 and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME()) -order by cert.not_after DESC NULLS FIRST +order by cert.temp_parent, cert.not_after DESC NULLS FIRST diff --git a/renewal/local.go b/renewal/local.go index 382f357..10957b1 100644 --- a/renewal/local.go +++ b/renewal/local.go @@ -11,6 +11,7 @@ type localCertData struct { name sql.NullString token sql.NullString } - notAfter sql.NullTime - domains []string + notAfter sql.NullTime + domains []string + tempParent uint64 } diff --git a/renewal/service.go b/renewal/service.go index 996ad30..0e9911d 100644 --- a/renewal/service.go +++ b/renewal/service.go @@ -35,6 +35,12 @@ var ( createTableCertificates string ) +const ( + DomainStateNormal = 0 + DomainStateAdded = 1 + DomainStateRemoved = 2 +) + // overrides only used in testing var testDnsOptions interface { challenge.Provider @@ -282,7 +288,7 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) { } // scan the first row - err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token) + err = row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token, &d.tempParent) switch err { case nil: // no nothing @@ -299,7 +305,7 @@ func (s *Service) findNextCertificateToRenew() (*localCertData, error) { func (s *Service) fetchDomains(localData *localCertData) ([]string, error) { // more sql: this one just grabs all the domains for a certificate - query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, localData.id) + query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, resolveTempParent(localData)) if err != nil { return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err) } @@ -343,9 +349,9 @@ func (s *Service) setupLegoClient(localData *localCertData) (*lego.Client, error if testDnsOptions != nil { // set up the dns provider used during tests and disable propagation as no dns // will validate these tests - dnsAddrs := testDnsOptions.GetDnsAddrs() - log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddrs) - _ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddrs), dns01.DisableCompletePropagationRequirement()) + dnsAddr := testDnsOptions.GetDnsAddrs() + log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddr) + _ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddr), dns01.DisableCompletePropagationRequirement()) } else if localData.dns.name.Valid && localData.dns.token.Valid { // if the dns name and token are "valid" meaning non-null in this case // set up the specific dns provider requested @@ -413,6 +419,12 @@ func (s *Service) renewCert(localData *localCertData) error { return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err) } + // set domains to normal state + _, err = s.db.Exec(`UPDATE certificate_domains SET state = ? WHERE cert_id = ?`, DomainStateNormal, localData.id) + if err != nil { + return fmt.Errorf("failed to update domains for %d in database: %w", localData.id, err) + } + // write out the certificate file err = s.writeCertFile(localData.id, certBytes) if err != nil { @@ -506,3 +518,10 @@ func (s *Service) writeCertFile(id uint64, certBytes []byte) error { return nil } + +func resolveTempParent(local *localCertData) uint64 { + if local.tempParent > 0 { + return local.tempParent + } + return local.id +} diff --git a/servers/api.go b/servers/api.go index 4e94a4c..3175c2e 100644 --- a/servers/api.go +++ b/servers/api.go @@ -1,40 +1,88 @@ package servers import ( + "database/sql" + "encoding/json" + "fmt" "github.com/MrMelon54/mjwt" - "github.com/MrMelon54/mjwt/auth" oUtils "github.com/MrMelon54/orchid/utils" vUtils "github.com/MrMelon54/violet/utils" + "github.com/golang-jwt/jwt/v4" "github.com/julienschmidt/httprouter" "net/http" + "strconv" "time" ) +type DomainStateValue struct { + Domain string `json:"domain"` + State int `json:"state"` +} + // NewApiServer creates and runs a http server containing all the API // endpoints for the software // // `/cert` - edit certificate -func NewApiServer(listen string, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server { +func NewApiServer(listen string, db *sql.DB, signer mjwt.Verifier, domains oUtils.DomainChecker) *http.Server { r := httprouter.New() - // Endpoint for adding a certificate - r.POST("/cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { - // TODO: register domains to a certificate - vUtils.RespondVioletError(rw, http.StatusNotImplemented, "API unavailable") - rw.WriteHeader(http.StatusNotImplemented) - return - - if !hasPerms(signer, req, "orchid:cert:") { - vUtils.RespondHttpStatus(rw, http.StatusForbidden) - return - } + // Endpoint for looking up a certificate + r.GET("/lookup/:domain", checkAuthWithPerm(signer, "orchid:cert", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { domain := params.ByName("domain") if !domains.ValidateDomain(domain) { vUtils.RespondVioletError(rw, http.StatusBadRequest, "Invalid domain") return } + })) + + r.POST("/cert", checkAuthWithPerm(signer, "orchid:cert:create", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + _, err := db.Exec(`INSERT INTO certificates (owner, dns, updated_at) VALUES (?, ?, ?)`, b.Subject, 0, time.Now()) + if err != nil { + apiError(rw, http.StatusInternalServerError, "Failed to delete certificate") + return + } rw.WriteHeader(http.StatusAccepted) - }) + })) + r.DELETE("/cert/:id", checkAuthForCertificate(signer, "orchid:cert:delete", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { + _, err := db.Exec(`UPDATE certificates SET active = 0 WHERE id = ?`, certId) + if err != nil { + apiError(rw, http.StatusInternalServerError, "Failed to delete certificate") + return + } + rw.WriteHeader(http.StatusAccepted) + })) + + // Endpoint for adding/removing domains to/from a certificate + manageGet, managePutDelete := certDomainManageGET(db, signer), certDomainManagePUTandDELETE(db, signer, domains) + r.GET("/cert/:id/domains", manageGet) + r.PUT("/cert/:id/domains", managePutDelete) + r.DELETE("/cert/:id/domains", managePutDelete) + + // Endpoint for generating a temporary certificate for modified domains + r.POST("/cert/:id/temp", checkAuth(signer, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + if !b.Claims.Perms.Has("orchid:cert:quick") { + apiError(rw, http.StatusForbidden, "No permission") + return + } + + // lookup certificate owner + id, err := checkCertOwner(db, "", b) + if err != nil { + apiError(rw, http.StatusInsufficientStorage, "Database error") + return + } + + // run a safe transaction to create the temporary certificate + if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error { + // insert temporary certificate into database + _, err := db.Exec(`INSERT INTO certificates (owner, dns, active, updated_at, temp_parent) VALUES (?, 0, 1, ?, ?)`, b.Subject, time.Now(), id) + return err + }) != nil { + apiError(rw, http.StatusInsufficientStorage, "Database error") + fmt.Printf("Internal error: %s\n", err) + return + } + })) // Create and run http server return &http.Server{ @@ -48,19 +96,83 @@ func NewApiServer(listen string, signer mjwt.Verifier, domains oUtils.DomainChec } } -func hasPerms(verify mjwt.Verifier, req *http.Request, perm string) bool { - // Get bearer token - bearer := vUtils.GetBearer(req) - if bearer == "" { - return false - } - - // Read claims from mjwt - _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer) - if err != nil { - return false - } - - // Token must have perm - return b.Claims.Perms.Has(perm) +// apiError outputs a generic JSON error message +func apiError(rw http.ResponseWriter, code int, m string) { + rw.WriteHeader(code) + _ = json.NewEncoder(rw).Encode(map[string]string{ + "error": m, + }) +} + +// lookupCertOwner finds the certificate matching the id string and returns the +// numeric id, owner and possible error, only works for active certificates. +func checkCertOwner(db *sql.DB, idStr string, b AuthClaims) (uint64, error) { + // parse the id + rawId, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return 0, err + } + + // run database query + row := db.QueryRow(`SELECT id, owner FROM certificates WHERE active = 1 and id = ?`, rawId) + + // scan in result values + var id uint64 + var owner string + err = row.Scan(&id, &owner) + if err != nil { + return 0, fmt.Errorf("scan error: %w", err) + } + + // check the owner is the mjwt token subject + if b.Subject != owner { + return id, fmt.Errorf("not the certificate owner") + } + + // it's all valid, return the values + return id, nil +} + +// safeTransaction completes a database transaction safely allowing for rollbacks +// if the callback errors +func safeTransaction(rw http.ResponseWriter, db *sql.DB, cb func(rw http.ResponseWriter, tx *sql.Tx) error) error { + // start a transaction + begin, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin a transaction") + } + + // init defer rollback + needsRollback := true + defer func() { + if needsRollback { + _ = begin.Rollback() + } + }() + + // run main code within the transaction session + err = cb(rw, begin) + if err != nil { + return err + } + + // clear the rollback flag and commit the transaction + needsRollback = false + if begin.Commit() != nil { + return fmt.Errorf("failed to commit a transaction") + } + return nil +} + +// validateDomainAudienceClaims validates if the audience claims contain the +// `owns=` field with the matching top level domain +func validateDomainAudienceClaims(a string, aud jwt.ClaimStrings) bool { + if fqdn, ok := vUtils.GetTopFqdn(a); ok { + for _, i := range aud { + if i == "owns="+fqdn { + return true + } + } + } + return false } diff --git a/servers/auth.go b/servers/auth.go new file mode 100644 index 0000000..dc021d5 --- /dev/null +++ b/servers/auth.go @@ -0,0 +1,66 @@ +package servers + +import ( + "database/sql" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/auth" + vUtils "github.com/MrMelon54/violet/utils" + "github.com/julienschmidt/httprouter" + "net/http" +) + +type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims] + +type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) + +type CertAuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) + +// checkAuth validates the bearer token against a mjwt.Verifier and returns an +// error message or continues to the next handler +func checkAuth(verify mjwt.Verifier, cb AuthCallback) httprouter.Handle { + return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + // Get bearer token + bearer := vUtils.GetBearer(req) + if bearer == "" { + apiError(rw, http.StatusForbidden, "Missing bearer token") + return + } + + // Read claims from mjwt + _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](verify, bearer) + if err != nil { + apiError(rw, http.StatusForbidden, "Invalid token") + return + } + + cb(rw, req, params, AuthClaims(b)) + } +} + +// checkAuthWithPerm validates the bearer token and checks if it contains a +// required permission and returns an error message or continues to the next +// handler +func checkAuthWithPerm(verify mjwt.Verifier, perm string, cb AuthCallback) httprouter.Handle { + return checkAuth(verify, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + // check perms + if !b.Claims.Perms.Has(perm) { + apiError(rw, http.StatusForbidden, "No permission") + return + } + cb(rw, req, params, b) + }) +} + +// checkAuthForCertificate +func checkAuthForCertificate(verify mjwt.Verifier, perm string, db *sql.DB, cb CertAuthCallback) httprouter.Handle { + return checkAuthWithPerm(verify, perm, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + // lookup certificate owner + id, err := checkCertOwner(db, "", b) + if err != nil { + apiError(rw, http.StatusInsufficientStorage, "Database error") + return + } + + cb(rw, req, params, b, id) + }) +} diff --git a/servers/certDomainManage.go b/servers/certDomainManage.go new file mode 100644 index 0000000..29ef316 --- /dev/null +++ b/servers/certDomainManage.go @@ -0,0 +1,108 @@ +package servers + +import ( + "database/sql" + "encoding/json" + "fmt" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/orchid/renewal" + "github.com/MrMelon54/orchid/utils" + "github.com/julienschmidt/httprouter" + "net/http" +) + +func certDomainManageGET(db *sql.DB, signer mjwt.Verifier) httprouter.Handle { + return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { + query, err := db.Query(`SELECT domain, state FROM certificate_domains WHERE cert_id = ?`, certId) + if err != nil { + apiError(rw, http.StatusInsufficientStorage, "Database error") + return + } + + // collect all the domains and state values + var domainStates []DomainStateValue + for query.Next() { + var a DomainStateValue + err := query.Scan(&a.Domain, &a.State) + if err != nil { + apiError(rw, http.StatusInsufficientStorage, "Database error") + return + } + domainStates = append(domainStates, a) + } + + // write output + rw.WriteHeader(http.StatusAccepted) + m := map[string]any{ + "id": fmt.Sprintf("%d", certId), + "domains": domainStates, + } + _ = json.NewEncoder(rw).Encode(m) + }) +} + +func certDomainManagePUTandDELETE(db *sql.DB, signer mjwt.Verifier, domains utils.DomainChecker) httprouter.Handle { + return checkAuthForCertificate(signer, "orchid:cert:edit", db, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims, certId uint64) { + // check request type + isAdd := req.Method == http.MethodPut + + if len(b.Audience) == 0 { + apiError(rw, http.StatusForbidden, "Missing audience tag, to specify owned domains") + return + } + + // read domains from request body + var d []string + if json.NewDecoder(req.Body).Decode(&d) != nil { + apiError(rw, http.StatusBadRequest, "Invalid request body") + return + } + + // validate all domains + for _, i := range d { + if !validateDomainAudienceClaims(i, b.Audience) { + apiError(rw, http.StatusBadRequest, "Token cannot modify a specified domain") + return + } + if !domains.ValidateDomain(i) { + apiError(rw, http.StatusBadRequest, "Invalid domain") + return + } + } + + // run a safe transaction to insert or update the certificate domains + if safeTransaction(rw, db, func(rw http.ResponseWriter, tx *sql.Tx) error { + if isAdd { + // insert domains to add + for _, i := range d { + _, err := tx.Exec(`INSERT INTO certificate_domains (cert_id, domain, state) VALUES (?, ?, ?)`, certId, i, renewal.DomainStateAdded) + if err != nil { + return fmt.Errorf("failed to add domains to the database") + } + } + } else { + // update domains to removed state + _, err := tx.Exec(`UPDATE certificate_domains SET state = ? WHERE domain IN ?`, renewal.DomainStateRemoved, d) + if err != nil { + return fmt.Errorf("failed to remove domains from the database") + } + } + return nil + }) != nil { + apiError(rw, http.StatusInsufficientStorage, "Database error") + return + } + + // write output + rw.WriteHeader(http.StatusAccepted) + m := map[string]any{ + "id": fmt.Sprintf("%d", certId), + } + if isAdd { + m["add_domains"] = d + } else { + m["remove_domains"] = d + } + _ = json.NewEncoder(rw).Encode(m) + }) +}