diff --git a/go.mod b/go.mod index 6213037..985c24c 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,12 @@ module github.com/MrMelon54/orchid go 1.20 require ( + github.com/MrMelon54/certgen v0.0.1 github.com/MrMelon54/mjwt v0.1.0 github.com/go-acme/lego/v4 v4.12.3 + github.com/google/uuid v1.3.0 + github.com/mattn/go-sqlite3 v1.14.17 + github.com/miekg/dns v1.1.50 github.com/stretchr/testify v1.8.4 ) @@ -14,7 +18,6 @@ require ( github.com/go-jose/go-jose/v3 v3.0.0 // indirect github.com/golang-jwt/jwt/v4 v4.4.3 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/miekg/dns v1.1.50 // indirect github.com/nrdcg/namesilo v0.2.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 9b857d0..c800aa7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/MrMelon54/certgen v0.0.1 h1:ycWdZ2RlxQ5qSuejeBVv4aXjGo5hdqqL4j4EjrXnFMk= +github.com/MrMelon54/certgen v0.0.1/go.mod h1:GHflVlSbtFLJZLpN1oWyUvDBRrR8qCWiwZLXCCnS2Gc= github.com/MrMelon54/mjwt v0.1.0 h1:x1wBrh9l2CowRekHecxcZaH2zy9Hvqwlp4ppmW1P1OA= github.com/MrMelon54/mjwt v0.1.0/go.mod h1:oYrDBWK09Hju98xb+bRQ0wy+RuAzacxYvKYOZchR2Tk= github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4= @@ -17,6 +19,10 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/nrdcg/namesilo v0.2.1 h1:kLjCjsufdW/IlC+iSfAqj0iQGgKjlbUUeDJio5Y6eMg= diff --git a/pebble-dev/normal.go b/pebble-dev/normal.go deleted file mode 100644 index 20196b7..0000000 --- a/pebble-dev/normal.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !DEBUG - -package pebble_dev - -import "log" - -func GetPebbleCert() []byte { - log.Fatalln("[Renewal] Pebble is selected as the certificate source but this binary was not compiled in debug mode") - return nil -} diff --git a/pebble-dev/debug.go b/pebble/asset/pebble-cert.pem similarity index 92% rename from pebble-dev/debug.go rename to pebble/asset/pebble-cert.pem index 2c21c76..a69a4c4 100644 --- a/pebble-dev/debug.go +++ b/pebble/asset/pebble-cert.pem @@ -1,9 +1,3 @@ -//go:build DEBUG - -package pebble_dev - -func GetPebbleCert() []byte { - return []byte(` -----BEGIN CERTIFICATE----- MIIDCTCCAfGgAwIBAgIIJOLbes8sTr4wDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE AxMVbWluaWNhIHJvb3QgY2EgMjRlMmRiMCAXDTE3MTIwNjE5NDIxMFoYDzIxMTcx @@ -23,5 +17,3 @@ Mfn3qEb9BXSk0Q3prNV5sOV3vgjEtB4THfDxSz9z3+DepVnW3vbbqwEbkXdk3j82 2muVldgOUgTwK8eT+XdofVdntzU/kzygSAtAQwLJfn51fS1GvEcYGBc1bDryIqmF p9BI7gVKtWSZYegicA== -----END CERTIFICATE----- -`) -} diff --git a/pebble/asset/pebble-config.json b/pebble/asset/pebble-config.json new file mode 100644 index 0000000..e0595be --- /dev/null +++ b/pebble/asset/pebble-config.json @@ -0,0 +1,19 @@ +{ + "pebble": { + "listenAddress": "0.0.0.0:14000", + "managementListenAddress": "0.0.0.0:15000", + "certificate": "certs/localhost/cert.pem", + "privateKey": "certs/localhost/key.pem", + "httpPort": 5002, + "tlsPort": 5001, + "ocspResponderURL": "", + "externalAccountBindingRequired": false, + "domainBlocklist": [ + "blocked-domain.example" + ], + "retryAfter": { + "authz": 3, + "order": 5 + } + } +} diff --git a/pebble/pebble.go b/pebble/pebble.go new file mode 100644 index 0000000..bfaf4f0 --- /dev/null +++ b/pebble/pebble.go @@ -0,0 +1,12 @@ +//go:build !DEBUG + +package pebble + +import _ "embed" + +var ( + //go:embed asset/pebble-cert.pem + RawCert []byte + //go:embed asset/pebble-config.json + RawConfig []byte +) diff --git a/renewal/config.go b/renewal/config.go index ed94376..27359ad 100644 --- a/renewal/config.go +++ b/renewal/config.go @@ -7,4 +7,5 @@ type LetsEncryptConfig struct { } `yaml:"account"` Directory string `yaml:"directory"` Certificate string `yaml:"certificate"` + insecure bool } diff --git a/renewal/find-next-cert.sql b/renewal/find-next-cert.sql index 385eaa2..315227e 100644 --- a/renewal/find-next-cert.sql +++ b/renewal/find-next-cert.sql @@ -1,11 +1,9 @@ -select cert.id, certdata.data_id, certdata.not_after, dns.type, dns.token +select cert.id, cert.not_after, dns.type, dns.token from certificates as cert - left outer join certificate_data as certdata on cert.id = certdata.meta_id left outer join dns on cert.dns = dns.id where cert.active = 1 and cert.auto_renew = 1 and cert.renewing = 0 and cert.renew_failed = 0 - and (certdata.ready IS NULL or certdata.ready = 1) - and (certdata.not_after IS NULL or DATETIME(certdata.not_after, 'utc', '-30 days') < DATETIME()) -order by certdata.not_after DESC NULLS FIRST + and (cert.not_after IS NULL or DATETIME(cert.not_after, 'utc', '-30 days') < DATETIME()) +order by cert.not_after DESC NULLS FIRST diff --git a/renewal/local.go b/renewal/local.go index 4d85809..382f357 100644 --- a/renewal/local.go +++ b/renewal/local.go @@ -1,17 +1,16 @@ package renewal -import "time" +import ( + "database/sql" +) // Contains local types for the renewal service type localCertData struct { id uint64 dns struct { - name string - token string + name sql.NullString + token sql.NullString } - cert struct { - current uint64 - notAfter time.Time - } - domains []string + notAfter sql.NullTime + domains []string } diff --git a/renewal/service.go b/renewal/service.go index 1fd68c4..f60924d 100644 --- a/renewal/service.go +++ b/renewal/service.go @@ -1,17 +1,19 @@ package renewal import ( + "bytes" "crypto/rsa" "crypto/tls" "crypto/x509" "database/sql" _ "embed" + "encoding/pem" "errors" "fmt" - "github.com/MrMelon54/orchid/http-acme" - "github.com/MrMelon54/orchid/pebble-dev" + "github.com/MrMelon54/orchid/pebble" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/challenge" + "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/providers/dns/namesilo" "github.com/go-acme/lego/v4/registration" @@ -32,23 +34,27 @@ var ( createTableCertificates string ) +var testDnsOptions interface { + challenge.Provider + GetDnsAddrs() []string +} + type Service struct { db *sql.DB - httpAcme *http_acme.HttpAcmeProvider + httpAcme challenge.Provider certTicker *time.Ticker certDone chan struct{} caAddr string caCert []byte - transport *http.Transport + transport http.RoundTripper renewLock *sync.Mutex leAccount *Account certDir string keyDir string - - //notify + insecure bool } -func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme *http_acme.HttpAcmeProvider, leConfig LetsEncryptConfig) (*Service, error) { +func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme challenge.Provider, leConfig LetsEncryptConfig, certDir, keyDir string) (*Service, error) { r := &Service{ db: db, httpAcme: httpAcme, @@ -57,19 +63,37 @@ func NewRenewalService(wg *sync.WaitGroup, db *sql.DB, httpAcme *http_acme.HttpA renewLock: &sync.Mutex{}, leAccount: &Account{ email: leConfig.Account.Email, - key: leConfig.Account.PrivateKey, }, + certDir: certDir, + keyDir: keyDir, + insecure: leConfig.insecure, + } + + // make certDir and keyDir + err := os.MkdirAll(certDir, os.ModePerm) + if err != nil { + return nil, fmt.Errorf("failed to create certDir '%s': %w", certDir, err) + } + err = os.MkdirAll(keyDir, os.ModePerm) + if err != nil { + return nil, fmt.Errorf("failed to create certDir '%s': %w", certDir, err) + } + + // load lets encrypt private key + err = r.resolveLEPrivKey(leConfig.Account.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to resolve LetsEncrypt account private key: %w", err) } // init domains table - _, err := r.db.Exec(createTableCertificates) + _, err = r.db.Exec(createTableCertificates) if err != nil { return nil, fmt.Errorf("failed to create certificates table: %w", err) } // resolve CA information - r.resolveCADirectory(leConfig) - err = r.resolveCACertificate(leConfig) + r.resolveCADirectory(leConfig.Directory) + err = r.resolveCACertificate(leConfig.Certificate) if err != nil { return nil, fmt.Errorf("failed to resolve CA certificate: %w", err) } @@ -84,41 +108,55 @@ func (s *Service) Shutdown() { close(s.certDone) } -func (s *Service) resolveCADirectory(conf LetsEncryptConfig) { - switch conf.Directory { +func (s *Service) resolveLEPrivKey(a string) error { + key, err := x509.ParsePKCS1PrivateKey([]byte(a)) + if err != nil { + bytes, err := os.ReadFile(a) + if err != nil { + return err + } + key, err = x509.ParsePKCS1PrivateKey(bytes) + } + s.leAccount.key = key + return err +} + +func (s *Service) resolveCADirectory(dir string) { + switch dir { case "production", "prod": s.caAddr = lego.LEDirectoryProduction case "staging": s.caAddr = lego.LEDirectoryStaging default: - s.caAddr = conf.Directory + s.caAddr = dir } } -func (s *Service) resolveCACertificate(conf LetsEncryptConfig) error { - switch conf.Certificate { +func (s *Service) resolveCACertificate(cert string) error { + switch cert { case "default": // no nothing case "pebble": - s.caCert = pebble_dev.GetPebbleCert() + s.caCert = pebble.RawCert + case "insecure": + s.caCert = []byte{0x00} default: - caGet, err := http.Get(conf.Certificate) - if err != nil { - return fmt.Errorf("failed to download CA certificate: %w", err) - } - s.caCert, err = io.ReadAll(caGet.Body) - if err != nil { - return fmt.Errorf("failed to read CA certificate: %w", err) - } + s.caCert = []byte(cert) } if s.caCert != nil { - caPool := x509.NewCertPool() - if !caPool.AppendCertsFromPEM(s.caCert) { - return fmt.Errorf("failed to add certificate to CA cert pool") + if bytes.Compare([]byte{0x00}, s.caCert) == 0 { + t := http.DefaultTransport.(*http.Transport).Clone() + t.TLSClientConfig.InsecureSkipVerify = true + s.transport = t + } else { + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(s.caCert) { + return fmt.Errorf("failed to add certificate to CA cert pool") + } + t := http.DefaultTransport.(*http.Transport).Clone() + t.TLSClientConfig = &tls.Config{RootCAs: caPool} + s.transport = t } - t := http.DefaultTransport.(*http.Transport).Clone() - t.TLSClientConfig = &tls.Config{RootCAs: caPool} - s.transport = t } return nil } @@ -171,14 +209,19 @@ func (s *Service) renewalCheck() error { return nil } - s.renewCert(localData) + err = s.renewCert(localData) + if err != nil { + return err + } + log.Printf("[Renewal] Updated certificate %d successfully\n", localData.id) + return nil } func (s *Service) findNextCertificateToRenew() (*localCertData, error) { d := &localCertData{} row := s.db.QueryRow(findNextCertSql) - err := row.Scan(&d.id, &d.cert.current, &d.cert.notAfter, &d.dns.name, &d.dns.token) + err := row.Scan(&d.id, &d.notAfter, &d.dns.name, &d.dns.token) switch err { case nil: // no nothing @@ -220,19 +263,29 @@ func (s *Service) setupLegoClient(localData *localCertData) (*lego.Client, error if s.transport != nil { config.HTTPClient.Transport = s.transport } - dnsProv, err := s.getDnsProvider(localData.dns.name, localData.dns.token) - if err != nil { - return nil, fmt.Errorf("failed to resolve dns provider: %w", err) - } client, err := lego.NewClient(config) if err != nil { return nil, fmt.Errorf("failed to generate client: %w", err) } - // set providers - always returns nil so ignore the error + // set http challenge provider _ = client.Challenge.SetHTTP01Provider(s.httpAcme) - _ = client.Challenge.SetDNS01Provider(dnsProv) + + // if testDnsOptions is defined then set up the test provider + if testDnsOptions != nil { + dnsAddrs := testDnsOptions.GetDnsAddrs() + log.Printf("Using testDnsOptions with DNS server: %v\n", dnsAddrs) + _ = client.Challenge.SetDNS01Provider(testDnsOptions, dns01.AddRecursiveNameservers(dnsAddrs), dns01.DisableCompletePropagationRequirement()) + } else if localData.dns.name.Valid && localData.dns.token.Valid { + // if the dns name and token are "valid" meaning non-null in this case + // set up the specific dns provider requested + dnsProv, err := s.getDnsProvider(localData.dns.name.String, localData.dns.token.String) + if err != nil { + return nil, fmt.Errorf("failed to resolve dns provider: %w", err) + } + _ = client.Challenge.SetDNS01Provider(dnsProv) + } register, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) if err != nil { @@ -260,45 +313,26 @@ func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) { return x509.ParsePKCS1PrivateKey(privKeyBytes) } -func (s *Service) renewCert(localData *localCertData) { +func (s *Service) renewCert(localData *localCertData) error { s.setRenewing(localData.id, true, false) cert, certBytes, err := s.renewCertInternal(localData) if err != nil { - log.Printf("[Renewal Failed to renew cert %d: %s\n", localData.id, err) s.setRenewing(localData.id, false, true) - return + return fmt.Errorf("failed to renew cert %d: %w", localData.id, err) } _, err = s.db.Exec(`UPDATE certificates SET renewing = 0, renew_failed = 0, not_after = ?, updated_at = ? WHERE id = ?`, cert.NotAfter, cert.NotBefore, localData.id) if err != nil { - log.Printf("[Renewal] Failed to update certificate %d in database: %s\n", localData.id, err) - return + return fmt.Errorf("failed to update cert %d in database: %w", localData.id, err) } - oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", localData.id)) - newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", localData.id)) - - err = os.Rename(newPath, oldPath) + err = s.writeCertFile(localData.id, certBytes) if err != nil { - log.Printf("[Renewal] Failed to rename certificate file '%s' => '%s': %s\n", newPath, oldPath, err) - return + return fmt.Errorf("failed to write cert file: %w", err) } - openCertFile, err := os.Create(newPath) - if err != nil { - log.Printf("[Renewal] Failed to create certificate file '%s': %s\n", newPath, err) - return - } - defer openCertFile.Close() - - _, err = openCertFile.Write(certBytes) - if err != nil { - log.Printf("[Renewal] Failed to write certificate file '%s': %s\n", newPath, err) - return - } - - log.Printf("[Renewal] Updated certificate %d successfully\n", localData.id) + return nil } func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate, []byte, error) { @@ -320,6 +354,7 @@ func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate return nil, nil, fmt.Errorf("failed to generate a client: %w", err) } + // obtain new certificate - this call will hang until a certificate is ready obtain, err := client.Certificate.Obtain(certificate.ObtainRequest{ Domains: domains, PrivateKey: privKey, @@ -329,11 +364,19 @@ func (s *Service) renewCertInternal(localData *localCertData) (*x509.Certificate return nil, nil, fmt.Errorf("failed to obtain replacement certificate: %w", err) } - parseCert, err := x509.ParseCertificate(obtain.Certificate) + // extract the certificate data from pem encoding + p, _ := pem.Decode(obtain.Certificate) + if p.Type != "CERTIFICATE" { + return nil, nil, fmt.Errorf("invalid certificate type '%s'", p.Type) + } + + // parse the obtained certificate + parseCert, err := x509.ParseCertificate(p.Bytes) if err != nil { return nil, nil, fmt.Errorf("failed to parse new certificate: %w", err) } + // return the parsed and raw bytes return parseCert, obtain.Certificate, nil } @@ -343,3 +386,29 @@ func (s *Service) setRenewing(id uint64, renewing, failed bool) { log.Printf("[Renewal] Failed to set renewing/failed mode in database %d: %s\n", id, err) } } + +func (s *Service) writeCertFile(id uint64, certBytes []byte) error { + oldPath := filepath.Join(s.certDir, fmt.Sprintf("%d-old.cert.pem", id)) + newPath := filepath.Join(s.certDir, fmt.Sprintf("%d.cert.pem", id)) + + // move certificate file to old name + err := os.Rename(newPath, oldPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to rename cert file '%s' => '%s': %w", newPath, oldPath, err) + } + + // create new certificate file + openCertFile, err := os.Create(newPath) + if err != nil { + return fmt.Errorf("failed to create cert file '%s': %s", newPath, err) + } + defer openCertFile.Close() + + // write certificate bytes + _, err = openCertFile.Write(certBytes) + if err != nil { + return fmt.Errorf("failed to write cert file '%s': %s", newPath, err) + } + + return nil +} diff --git a/renewal/service_test.go b/renewal/service_test.go index 629f895..0c11361 100644 --- a/renewal/service_test.go +++ b/renewal/service_test.go @@ -1 +1,217 @@ package renewal + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "database/sql" + "encoding/pem" + "fmt" + "github.com/MrMelon54/certgen" + "github.com/MrMelon54/orchid/pebble" + "github.com/MrMelon54/orchid/test" + "github.com/go-acme/lego/v4/lego" + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "go/build" + "io" + "log" + "math/big" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" +) + +const pebbleUrl = "https://localhost:5000" + +func TestService_resolveCADirectory(t *testing.T) { + s := &Service{} + s.resolveCADirectory("production") + assert.Equal(t, lego.LEDirectoryProduction, s.caAddr) + s.resolveCADirectory("prod") + assert.Equal(t, lego.LEDirectoryProduction, s.caAddr) + s.resolveCADirectory("staging") + assert.Equal(t, lego.LEDirectoryStaging, s.caAddr) + s.resolveCADirectory(pebbleUrl) + assert.Equal(t, pebbleUrl, s.caAddr) +} + +func TestService_resolveCACertificate(t *testing.T) { + s := &Service{} + assert.NoError(t, s.resolveCACertificate("default")) + assert.Nil(t, s.caCert) + assert.NoError(t, s.resolveCACertificate("pebble")) + assert.Equal(t, 0, bytes.Compare(pebble.RawCert, s.caCert)) +} + +func setupPebbleSuite(tb testing.TB) (*certgen.CertGen, func()) { + fmt.Println("Running pebble") + pebbleTmp, err := os.MkdirTemp("", "pebble") + assert.NoError(tb, err) + assert.NoError(tb, os.WriteFile(filepath.Join(pebbleTmp, "pebble-config.json"), pebble.RawConfig, os.ModePerm)) + + serverTls, err := certgen.MakeServerTls(nil, 2048, pkix.Name{ + Country: []string{"GB"}, + Organization: []string{"Orchid"}, + OrganizationalUnit: []string{"Test"}, + SerialNumber: "0", + CommonName: "localhost", + }, big.NewInt(1), func(now time.Time) time.Time { + return now.AddDate(10, 0, 0) + }, []string{"localhost", "pebble"}, []net.IP{net.IPv4(127, 0, 0, 1)}) + assert.NoError(tb, err) + assert.NoError(tb, os.MkdirAll(filepath.Join(pebbleTmp, "certs", "localhost"), os.ModePerm)) + assert.NoError(tb, os.WriteFile(filepath.Join(pebbleTmp, "certs", "localhost", "cert.pem"), serverTls.GetCertPem(), os.ModePerm)) + assert.NoError(tb, os.WriteFile(filepath.Join(pebbleTmp, "certs", "localhost", "key.pem"), serverTls.GetKeyPem(), os.ModePerm)) + + // hack default resolver + net.DefaultResolver.PreferGo = true + net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + tb.Logf("Custom Resolver %s - %s\n", network, address) + return nil, fmt.Errorf("ha failed") + } + + dnsServer := test.MakeFakeDnsProv("127.0.0.34:5053") // 127.0.0.34:53 + dnsServer.AddRecursiveSOA("example.test.") + go dnsServer.Start() + testDnsOptions = dnsServer + + pebbleFile := filepath.Join(build.Default.GOPATH, "bin", "pebble") + command := exec.Command(pebbleFile, "-config", filepath.Join(pebbleTmp, "pebble-config.json"), "-dnsserver", "127.0.0.34:5053") + command.Env = append(command.Env, "PEBBLE_VA_ALWAYS_VALID=1") + command.Dir = pebbleTmp + + if command.Start() != nil { + instCmd := exec.Command("go", "install", "github.com/letsencrypt/pebble/cmd/pebble@latest") + assert.NoError(tb, instCmd.Run(), "Failed to start pebble make sure it is installed... go install github.com/letsencrypt/pebble/cmd/pebble@latest") + assert.NoError(tb, command.Start(), "failed to start pebble again") + } + + return serverTls, func() { + // unhack default resolver + net.DefaultResolver.PreferGo = false + net.DefaultResolver.Dial = nil + + fmt.Println("Killing pebble") + if command != nil && command.Process != nil { + assert.NoError(tb, command.Process.Kill()) + } + dnsServer.Shutdown() + } +} + +func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service { + wg := &sync.WaitGroup{} + dbFile := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString()) + db, err := sql.Open("sqlite3", dbFile) + assert.NoError(t, err) + log.Println("DB File:", dbFile) + + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig.InsecureSkipVerify = true + req, err := http.NewRequest(http.MethodGet, "https://localhost:14000/root", nil) + assert.NoError(t, err) + res, err := tr.RoundTrip(req) + assert.NoError(t, err) + certRaw, err := io.ReadAll(res.Body) + assert.NoError(t, err) + fmt.Println("Cert:", string(certRaw)) + + certDir, err := os.MkdirTemp("", "orchid-certs") + keyDir, err := os.MkdirTemp("", "orchid-keys") + + lePrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + acmeProv := test.MakeFakeAcmeProv(serverTls.GetCertPem()) + service, err := NewRenewalService(wg, db, acmeProv, LetsEncryptConfig{ + Account: struct { + Email string `yaml:"email"` + PrivateKey string `yaml:"privateKey"` + }{ + Email: "webmaster@example.test", + PrivateKey: string(x509.MarshalPKCS1PrivateKey(lePrivKey)), + }, + Directory: "https://localhost:14000/dir", + Certificate: string(certRaw), + insecure: true, + }, certDir, keyDir) + service.transport = acmeProv + assert.NoError(t, err) + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), x509.MarshalPKCS1PrivateKey(privKey), os.ModePerm)) + + return service +} + +func TestPebbleRenewal(t *testing.T) { + serverTls, cancel := setupPebbleSuite(t) + t.Cleanup(cancel) + + time.Sleep(5 * time.Second) + + tests := []struct { + name string + domains []string + }{ + {"Test", []string{"hello.example.test"}}, + {"Test with multiple certificates", []string{"example.test", "world.example.test"}}, + {"Test with wildcard certificate", []string{"example.test", "*.example.test"}}, + } + + for _, i := range tests { + t.Run(i.name, func(t *testing.T) { + //t.Parallel() + service := setupPebbleTest(t, serverTls) + //goland:noinspection SqlWithoutWhere + _, err := service.db.Exec("DELETE FROM certificate_domains") + assert.NoError(t, err) + + _, err = service.db.Exec(`INSERT INTO certificates (owner, dns, auto_renew, active, renewing, renew_failed, not_after, updated_at) VALUES (1, 1, 1, 1, 0, 0, NULL, NULL)`) + assert.NoError(t, err) + for _, j := range i.domains { + _, err = service.db.Exec(`INSERT INTO certificate_domains (cert_id, domain) VALUES (1, ?)`, j) + assert.NoError(t, err) + } + + fmt.Println("Database rows") + fmt.Println("=============") + query, err := service.db.Query("SELECT cert_id, domain from certificate_domains") + assert.NoError(t, err) + for query.Next() { + var a uint64 + var b string + assert.NoError(t, query.Scan(&a, &b)) + + fmt.Println(a, b) + } + fmt.Println("=============") + + assert.NoError(t, service.renewalCheck()) + certFilePath := filepath.Join(service.certDir, "1.cert.pem") + certFileRaw, err := os.ReadFile(certFilePath) + assert.NoError(t, err) + + p, _ := pem.Decode(certFileRaw) + assert.NotNil(t, p) + if p == nil { + t.FailNow() + } + assert.Equal(t, "CERTIFICATE", p.Type) + outCert, err := x509.ParseCertificate(p.Bytes) + assert.NoError(t, err) + assert.Equal(t, i.domains, outCert.DNSNames) + }) + } +} diff --git a/test/fakeAcmeProv.go b/test/fakeAcmeProv.go new file mode 100644 index 0000000..2aaacc5 --- /dev/null +++ b/test/fakeAcmeProv.go @@ -0,0 +1,35 @@ +package test + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "github.com/go-acme/lego/v4/challenge" + "net/http" +) + +// fakeAcmeProv an acme provider to emulate +type fakeAcmeProv struct { + t http.RoundTripper +} + +func MakeFakeAcmeProv(cert []byte) interface { + challenge.Provider + http.RoundTripper +} { + cp := x509.NewCertPool() + cp.AppendCertsFromPEM(cert) + t := http.DefaultTransport.(*http.Transport).Clone() + t.TLSClientConfig = &tls.Config{RootCAs: cp} + return &fakeAcmeProv{t: t} +} + +func (f *fakeAcmeProv) Present(string, string, string) error { return nil } +func (f *fakeAcmeProv) CleanUp(string, string, string) error { return nil } +func (f *fakeAcmeProv) RoundTrip(req *http.Request) (*http.Response, error) { + // use transport with custom CertPool for pebble requests + if req.URL.Host == "localhost:14000" { + return f.t.RoundTrip(req) + } + return nil, fmt.Errorf("invalid fakeAcmeProv.RoundTrip call to '%s'", req.URL.String()) +} diff --git a/test/fakeDnsProv.go b/test/fakeDnsProv.go new file mode 100644 index 0000000..80c673a --- /dev/null +++ b/test/fakeDnsProv.go @@ -0,0 +1,130 @@ +package test + +import ( + "fmt" + "github.com/go-acme/lego/v4/challenge" + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/miekg/dns" + "log" + "strings" +) + +type fakeDnsProv struct { + Addr string + mTxt map[string]string + srv *dns.Server + mSoa map[string][2]string +} + +func MakeFakeDnsProv(addr string) interface { + challenge.Provider + GetDnsAddrs() []string + Start() + Shutdown() + AddRecursiveSOA(fqdn string) +} { + return &fakeDnsProv{ + Addr: addr, + mTxt: make(map[string]string), + mSoa: make(map[string][2]string), + } +} + +func (f *fakeDnsProv) AddRecursiveSOA(fqdn string) { + n := fqdn + for { + f.mSoa[n] = [2]string{"ns." + n, "webmaster." + n} + + // find next subdomain separator and trim the fqdn + ni := strings.IndexByte(n, '.') + if ni <= 0 { + break + } + n = n[ni+1:] + } +} + +func (f *fakeDnsProv) Present(domain, _, keyAuth string) error { + info := dns01.GetChallengeInfo(domain, keyAuth) + f.mTxt[info.EffectiveFQDN] = info.Value + log.Printf("fakeDnsProv.Present(%s TXT %s)\n", info.EffectiveFQDN, info.Value) + return nil +} +func (f *fakeDnsProv) CleanUp(domain, _, keyAuth string) error { + info := dns01.GetChallengeInfo(domain, keyAuth) + delete(f.mTxt, info.EffectiveFQDN) + log.Printf("fakeDnsProv.CleanUp(%s TXT %s)\n", info.EffectiveFQDN, info.Value) + return nil +} +func (f *fakeDnsProv) GetDnsAddrs() []string { + fmt.Printf("Get dns addrs: %v\n", f.Addr) + return []string{f.Addr} +} + +func (f *fakeDnsProv) parseQuery(m *dns.Msg) { + for _, q := range m.Question { + switch q.Qtype { + case dns.TypeSOA: + log.Printf("Looking up %s SOA record\n", q.Name) + n := q.Name + for strings.Count(n, ".") > 3 { + // find next subdomain separator and trim the fqdn + ni := strings.IndexByte(n, '.') + if ni <= 0 { + break + } + n = n[ni+1:] + } + + // find an answer if possible + if strings.Count(q.Name, ".") == 3 { + rr, err := dns.NewRR(fmt.Sprintf("%s 32600 IN SOA %s %s 1687993787 86400 7200 4000000 11200", n, "ns.example.com.", "hostmaster.example.com.")) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } + case dns.TypeTXT: + log.Printf("Looking up %s TXT record\n", q.Name) + txt := f.mTxt[q.Name] + if txt != "" { + rr, err := dns.NewRR(fmt.Sprintf("%s 32600 IN TXT \"%s\"", q.Name, txt)) + if err == nil { + fmt.Println("response:", rr.String()) + m.Answer = append(m.Answer, rr) + } + } + default: + log.Printf("Looking up %d for %s\n", q.Qtype, q.Name) + } + } +} + +func (f *fakeDnsProv) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Compress = false + + switch r.Opcode { + case dns.OpcodeQuery: + f.parseQuery(m) + } + + _ = w.WriteMsg(m) +} + +func (f *fakeDnsProv) Start() { + // attach request handler func + dns.HandleFunc(".", f.handleDnsRequest) + + // start server + f.srv = &dns.Server{Addr: f.Addr, Net: "udp"} + log.Printf("Starting fake dns service at %s\n", f.srv.Addr) + err := f.srv.ListenAndServe() + if err != nil { + log.Fatalf("Failed to start server: %s\n ", err.Error()) + } +} + +func (f *fakeDnsProv) Shutdown() { + _ = f.srv.Shutdown() +}