mirror of
https://github.com/1f349/orchid.git
synced 2024-09-19 10:26:57 +01:00
346 lines
9.1 KiB
Go
346 lines
9.1 KiB
Go
|
package renewal
|
||
|
|
||
|
import (
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	_ "embed"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/MrMelon54/orchid/http-acme"
	"github.com/MrMelon54/orchid/pebble-dev"
	"github.com/go-acme/lego/v4/certificate"
	"github.com/go-acme/lego/v4/challenge"
	"github.com/go-acme/lego/v4/lego"
	"github.com/go-acme/lego/v4/providers/dns/namesilo"
	"github.com/go-acme/lego/v4/registration"
)
|
||
|
|
||
|
var (
|
||
|
ErrUnsupportedDNSProvider = errors.New("unsupported DNS provider")
|
||
|
//go:embed find-next-cert.sql
|
||
|
findNextCertSql string
|
||
|
//go:embed create-tables.sql
|
||
|
createTableCertificates string
|
||
|
)
|
||
|
|
||
|
type Service struct {
|
||
|
db *sql.DB
|
||
|
httpAcme *http_acme.HttpAcmeProvider
|
||
|
certTicker *time.Ticker
|
||
|
certDone chan struct{}
|
||
|
caAddr string
|
||
|
caCert []byte
|
||
|
transport *http.Transport
|
||
|
renewLock *sync.Mutex
|
||
|
leAccount *Account
|
||
|
certDir string
|
||
|
keyDir string
|
||
|
|
||
|
//notify
|
||
|
}
|
||
|
|
||
|
func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme *http_acme.HttpAcmeProvider, leConfig LetsEncryptConfig) (*Service, error) {
|
||
|
r := &Service{
|
||
|
db: db,
|
||
|
httpAcme: httpAcme,
|
||
|
certTicker: time.NewTicker(time.Minute * 10),
|
||
|
certDone: make(chan struct{}),
|
||
|
renewLock: &sync.Mutex{},
|
||
|
leAccount: &Account{
|
||
|
email: leConfig.Account.Email,
|
||
|
key: leConfig.Account.PrivateKey,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
// init domains table
|
||
|
_, err := r.db.Exec(createTableCertificates)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to create certificates table: %w", err)
|
||
|
}
|
||
|
|
||
|
// resolve CA information
|
||
|
r.resolveCADirectory(leConfig)
|
||
|
err = r.resolveCACertificate(leConfig)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to resolve CA certificate: %w", err)
|
||
|
}
|
||
|
|
||
|
wg.Add(1)
|
||
|
go r.renewalRoutine(wg)
|
||
|
return r, nil
|
||
|
}
|
||
|
|
||
|
func (s *Service) Shutdown() {
|
||
|
log.Println("[Renewal] Shutting down certificate renewal service")
|
||
|
close(s.certDone)
|
||
|
}
|
||
|
|
||
|
func (s *Service) resolveCADirectory(conf LetsEncryptConfig) {
|
||
|
switch conf.Directory {
|
||
|
case "production", "prod":
|
||
|
s.caAddr = lego.LEDirectoryProduction
|
||
|
case "staging":
|
||
|
s.caAddr = lego.LEDirectoryStaging
|
||
|
default:
|
||
|
s.caAddr = conf.Directory
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Service) resolveCACertificate(conf LetsEncryptConfig) error {
|
||
|
switch conf.Certificate {
|
||
|
case "default":
|
||
|
// no nothing
|
||
|
case "pebble":
|
||
|
s.caCert = pebble_dev.GetPebbleCert()
|
||
|
default:
|
||
|
caGet, err := http.Get(conf.Certificate)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to download CA certificate: %w", err)
|
||
|
}
|
||
|
s.caCert, err = io.ReadAll(caGet.Body)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to read CA certificate: %w", err)
|
||
|
}
|
||
|
}
|
||
|
if s.caCert != nil {
|
||
|
caPool := x509.NewCertPool()
|
||
|
if !caPool.AppendCertsFromPEM(s.caCert) {
|
||
|
return fmt.Errorf("failed to add certificate to CA cert pool")
|
||
|
}
|
||
|
t := http.DefaultTransport.(*http.Transport).Clone()
|
||
|
t.TLSClientConfig = &tls.Config{RootCAs: caPool}
|
||
|
s.transport = t
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// ErrAlreadyRenewing is returned by renewalCheck when another renewal check
// currently holds the renew lock.
var ErrAlreadyRenewing = errors.New("already renewing")
|
||
|
|
||
|
func (s *Service) renewalRoutine(wg *sync.WaitGroup) {
|
||
|
defer func() {
|
||
|
s.certTicker.Stop()
|
||
|
log.Println("[Renewal] Stopped certificate renewal service")
|
||
|
wg.Done()
|
||
|
}()
|
||
|
|
||
|
log.Println("[Renewal] Doing quick certificate check before starting...")
|
||
|
err := s.renewalCheck()
|
||
|
if err != nil {
|
||
|
log.Println("[Renewal] Certificate check, should not error first try: ", err)
|
||
|
return
|
||
|
}
|
||
|
log.Println("[Renewal] Initial check complete, continually checking every 4 hours...")
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-s.certDone:
|
||
|
return
|
||
|
case <-s.certTicker.C:
|
||
|
go func() {
|
||
|
err := s.renewalCheck()
|
||
|
if err != nil && err != ErrAlreadyRenewing {
|
||
|
log.Println("[Renewal] Certificate check, an error occurred: ", err)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Service) renewalCheck() error {
|
||
|
if !s.renewLock.TryLock() {
|
||
|
return ErrAlreadyRenewing
|
||
|
}
|
||
|
defer s.renewLock.Unlock()
|
||
|
|
||
|
localData, err := s.findNextCertificateToRenew()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to find a certificate to renew: %w", err)
|
||
|
}
|
||
|
|
||
|
// no certificates to update
|
||
|
if localData == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
s.renewCert(localData)
|
||
|
}
|
||
|
|
||
|
func (s *Service) findNextCertificateToRenew() (*localCertData, error) {
|
||
|
d := &localCertData{}
|
||
|
|
||
|
row := s.db.QueryRow(findNextCertSql)
|
||
|
err := row.Scan(&d.id, &d.cert.current, &d.cert.notAfter, &d.dns.name, &d.dns.token)
|
||
|
switch err {
|
||
|
case nil:
|
||
|
// no nothing
|
||
|
break
|
||
|
case io.EOF:
|
||
|
// no certificate to update
|
||
|
return nil, nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("failed to scan table row: %w", err)
|
||
|
}
|
||
|
|
||
|
return d, nil
|
||
|
}
|
||
|
|
||
|
func (s *Service) fetchDomains(localData *localCertData) ([]string, error) {
|
||
|
query, err := s.db.Query(`SELECT domain FROM certificate_domains WHERE cert_id = ?`, localData.id)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to fetch domains for certificate: %d: %w", localData.id, err)
|
||
|
}
|
||
|
|
||
|
domains := make([]string, 0)
|
||
|
for query.Next() {
|
||
|
var domain string
|
||
|
err := query.Scan(&domain)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to scan row from domains table: %d: %w", localData.id, err)
|
||
|
}
|
||
|
domains = append(domains, domain)
|
||
|
}
|
||
|
if len(domains) == 0 {
|
||
|
return nil, fmt.Errorf("no domains registered for certificate: %d", localData.id)
|
||
|
}
|
||
|
return domains, nil
|
||
|
}
|
||
|
|
||
|
func (s *Service) setupLegoClient(localData *localCertData) (*lego.Client, error) {
|
||
|
config := lego.NewConfig(s.leAccount)
|
||
|
config.CADirURL = s.caAddr
|
||
|
if s.transport != nil {
|
||
|
config.HTTPClient.Transport = s.transport
|
||
|
}
|
||
|
dnsProv, err := s.getDnsProvider(localData.dns.name, localData.dns.token)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to resolve dns provider: %w", err)
|
||
|
}
|
||
|
|
||
|
client, err := lego.NewClient(config)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to generate client: %w", err)
|
||
|
}
|
||
|
|
||
|
// set providers - always returns nil so ignore the error
|
||
|
_ = client.Challenge.SetHTTP01Provider(s.httpAcme)
|
||
|
_ = client.Challenge.SetDNS01Provider(dnsProv)
|
||
|
|
||
|
register, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to update account registration: %w", err)
|
||
|
}
|
||
|
|
||
|
s.leAccount.reg = register
|
||
|
return client, nil
|
||
|
}
|
||
|
|
||
|
func (s *Service) getDnsProvider(name, token string) (challenge.Provider, error) {
|
||
|
switch name {
|
||
|
case "namesilo":
|
||
|
return namesilo.NewDNSProviderConfig(&namesilo.Config{APIKey: token})
|
||
|
default:
|
||
|
return nil, ErrUnsupportedDNSProvider
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) {
|
||
|
privKeyBytes, err := os.ReadFile(filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id)))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return x509.ParsePKCS1PrivateKey(privKeyBytes)
|
||
|
}
|
||
|
|
||
|
func (s *Service) renewCert(localData *localCertData) {
|
||
|
s.setRenewing(localData.id, true, false)
|
||
|
|
||
|
cert, certBytes, err := s.renewCertInternal(localData)
|
||
|
if err != nil {
|
||
|
log.Printf("[Renewal Failed to renew cert %d: %s\n", localData.id, err)
|
||
|
s.setRenewing(localData.id, false, true)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
_, err = s.db.Exec(`UPDATE certificates SET renewing = 0, renew_failed = 0, not_after = ?, updated_at = ? WHERE id = ?`, cert.NotAfter, cert.NotBefore, localData.id)
|
||
|
if err != nil {
|
||
|
log.Printf("[Renewal] Failed to update certificate %d in database: %s\n", localData.id, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", localData.id))
|
||
|
newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", localData.id))
|
||
|
|
||
|
err = os.Rename(newPath, oldPath)
|
||
|
if err != nil {
|
||
|
log.Printf("[Renewal] Failed to rename certificate file '%s' => '%s': %s\n", newPath, oldPath, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
openCertFile, err := os.Create(newPath)
|
||
|
if err != nil {
|
||
|
log.Printf("[Renewal] Failed to create certificate file '%s': %s\n", newPath, err)
|
||
|
return
|
||
|
}
|
||
|
defer openCertFile.Close()
|
||
|
|
||
|
_, err = openCertFile.Write(certBytes)
|
||
|
if err != nil {
|
||
|
log.Printf("[Renewal] Failed to write certificate file '%s': %s\n", newPath, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
log.Printf("[Renewal] Updated certificate %d successfully\n", localData.id)
|
||
|
}
|
||
|
|
||
|
func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate, []byte, error) {
|
||
|
// read private key file
|
||
|
privKey, err := s.getPrivateKey(localData.id)
|
||
|
if err != nil {
|
||
|
return nil, nil, fmt.Errorf("failed to open private key: %w", err)
|
||
|
}
|
||
|
|
||
|
// fetch domains for this certificate
|
||
|
domains, err := s.fetchDomains(localData)
|
||
|
if err != nil {
|
||
|
return nil, nil, fmt.Errorf("failed to update cert: %w", err)
|
||
|
}
|
||
|
|
||
|
// setup client for requesting a new certificate
|
||
|
client, err := s.setupLegoClient(localData)
|
||
|
if err != nil {
|
||
|
return nil, nil, fmt.Errorf("failed to generate a client: %w", err)
|
||
|
}
|
||
|
|
||
|
obtain, err := client.Certificate.Obtain(certificate.ObtainRequest{
|
||
|
Domains: domains,
|
||
|
PrivateKey: privKey,
|
||
|
Bundle: true,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return nil, nil, fmt.Errorf("failed to obtain replacement certificate: %w", err)
|
||
|
}
|
||
|
|
||
|
parseCert, err := x509.ParseCertificate(obtain.Certificate)
|
||
|
if err != nil {
|
||
|
return nil, nil, fmt.Errorf("failed to parse new certificate: %w", err)
|
||
|
}
|
||
|
|
||
|
return parseCert, obtain.Certificate, nil
|
||
|
}
|
||
|
|
||
|
func (s *Service) setRenewing(id uint64, renewing, failed bool) {
|
||
|
_, err := s.db.Exec("UPDATE certificates SET renewing = ?, renew_failed = ? WHERE id = ?", renewing, failed, id)
|
||
|
if err != nil {
|
||
|
log.Printf("[Renewal] Failed to set renewing/failed mode in database %d: %s\n", id, err)
|
||
|
}
|
||
|
}
|