From 4188dccd1d4c48f14b4da641c9cc94b981b49dc5 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Tue, 11 Jul 2023 15:04:59 +0100 Subject: [PATCH] Auto generate domain certificate private key if it doesn't already exist --- http-acme/http-acme-provider.go | 14 +++++++------- http-acme/http-acme-provider_test.go | 2 +- renewal/service.go | 21 +++++++++++++++++++-- renewal/service_test.go | 2 +- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/http-acme/http-acme-provider.go b/http-acme/http-acme-provider.go index e936dc9..7969560 100644 --- a/http-acme/http-acme-provider.go +++ b/http-acme/http-acme-provider.go @@ -55,7 +55,7 @@ func (h *HttpAcmeProvider) Present(domain, token, keyAuth string) error { return err } if trip.StatusCode != http.StatusAccepted { - return fmt.Errorf("Trip response status code was not 200") + return fmt.Errorf("trip response status code was not 202") } return nil } @@ -69,7 +69,7 @@ func (h *HttpAcmeProvider) CleanUp(domain, token, keyAuth string) error { return err } if trip.StatusCode != http.StatusAccepted { - return fmt.Errorf("Trip response status code was not 200") + return fmt.Errorf("trip response status code was not 202") } return nil } @@ -83,7 +83,7 @@ func (h *HttpAcmeProvider) authCheckRequest(method, url, domain, token, keyAuth return nil, err } switch resp.StatusCode { - case http.StatusOK: + case http.StatusAccepted: // just return return resp, nil case http.StatusForbidden: @@ -99,8 +99,8 @@ func (h *HttpAcmeProvider) authCheckRequest(method, url, domain, token, keyAuth if err != nil { return nil, fmt.Errorf("refresh token request failed: %w", err) } - if trip.StatusCode != http.StatusOK { - return nil, fmt.Errorf("refresh token request failed: due to invalid status code, expected 200 got %d", trip.StatusCode) + if trip.StatusCode != http.StatusAccepted { + return nil, fmt.Errorf("refresh token request failed: due to invalid status code, expected 202 got %d", trip.StatusCode) } // parse tokens from response body @@ -125,10 +125,10 @@ func (h *HttpAcmeProvider) authCheckRequest(method, url, domain, token, keyAuth // just return return resp, nil } - return nil, fmt.Errorf("invalid status code, expected 200 got %d", resp.StatusCode) + return nil, fmt.Errorf("invalid status code, expected 202 got %d", resp.StatusCode) } // first request had an invalid status code - return nil, fmt.Errorf("invalid status code, expected 200/403 got %d", resp.StatusCode) + return nil, fmt.Errorf("invalid status code, expected 202/403 got %d", resp.StatusCode) } // internalRequest sends a request to the acme challenge hosting api diff --git a/http-acme/http-acme-provider_test.go b/http-acme/http-acme-provider_test.go index e2ff581..82d7312 100644 --- a/http-acme/http-acme-provider_test.go +++ b/http-acme/http-acme-provider_test.go @@ -55,7 +55,7 @@ func (f *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("missing perm 'test:acme:clean'") } rec := httptest.NewRecorder() - rec.WriteHeader(http.StatusOK) + rec.WriteHeader(http.StatusAccepted) f.req = req return rec.Result(), nil } diff --git a/renewal/service.go b/renewal/service.go index 5193ac0..d9091ef 100644 --- a/renewal/service.go +++ b/renewal/service.go @@ -20,6 +20,7 @@ import ( "github.com/go-acme/lego/v4/registration" "io" "log" + "math/rand" "net/http" "os" "path/filepath" @@ -390,13 +391,29 @@ func (s *Service) getDnsProvider(name, token string) (challenge.Provider, error) } } -// getPrivateKey reads the private key for the specified certificate id +// getPrivateKey reads the private key for the specified certificate id, or +// generates one is the file doesn't exist func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) { - pemBytes, err := os.ReadFile(filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id))) + fPath := filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id)) + pemBytes, err := os.ReadFile(fPath) if err != nil { + if os.IsNotExist(err) { + key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 4096) + if err != nil { + return nil, fmt.Errorf("generate rsa key error: %w", err) + } + err = os.WriteFile(fPath, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), os.ModePerm) + if err != nil { + return nil, fmt.Errorf("failed to save rsa key: %w", err) + } + return key, nil + } return nil, err } keyBlock, _ := pem.Decode(pemBytes) + if keyBlock == nil { + return nil, fmt.Errorf("invalid pem block: failed to parse") + } if keyBlock.Type != "RSA PRIVATE KEY" { return nil, fmt.Errorf("invalid pem block type") } diff --git a/renewal/service_test.go b/renewal/service_test.go index 94619b3..189eac8 100644 --- a/renewal/service_test.go +++ b/renewal/service_test.go @@ -136,7 +136,7 @@ func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service { privKey, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), x509.MarshalPKCS1PrivateKey(privKey), os.ModePerm)) + assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm)) return service }