Auto generate domain certificate private key if it doesn't already exist

This commit is contained in:
Melon 2023-07-11 15:04:59 +01:00
parent 78734a1e01
commit 4188dccd1d
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
4 changed files with 28 additions and 11 deletions

View File

@ -55,7 +55,7 @@ func (h *HttpAcmeProvider) Present(domain, token, keyAuth string) error {
return err return err
} }
if trip.StatusCode != http.StatusAccepted { if trip.StatusCode != http.StatusAccepted {
return fmt.Errorf("Trip response status code was not 200") return fmt.Errorf("trip response status code was not 202")
} }
return nil return nil
} }
@ -69,7 +69,7 @@ func (h *HttpAcmeProvider) CleanUp(domain, token, keyAuth string) error {
return err return err
} }
if trip.StatusCode != http.StatusAccepted { if trip.StatusCode != http.StatusAccepted {
return fmt.Errorf("Trip response status code was not 200") return fmt.Errorf("trip response status code was not 202")
} }
return nil return nil
} }
@ -83,7 +83,7 @@ func (h *HttpAcmeProvider) authCheckRequest(method, url, domain, token, keyAuth
return nil, err return nil, err
} }
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusAccepted:
// just return // just return
return resp, nil return resp, nil
case http.StatusForbidden: case http.StatusForbidden:
@ -99,8 +99,8 @@ func (h *HttpAcmeProvider) authCheckRequest(method, url, domain, token, keyAuth
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token request failed: %w", err) return nil, fmt.Errorf("refresh token request failed: %w", err)
} }
if trip.StatusCode != http.StatusOK { if trip.StatusCode != http.StatusAccepted {
return nil, fmt.Errorf("refresh token request failed: due to invalid status code, expected 200 got %d", trip.StatusCode) return nil, fmt.Errorf("refresh token request failed: due to invalid status code, expected 202 got %d", trip.StatusCode)
} }
// parse tokens from response body // parse tokens from response body
@ -125,10 +125,10 @@ func (h *HttpAcmeProvider) authCheckRequest(method, url, domain, token, keyAuth
// just return // just return
return resp, nil return resp, nil
} }
return nil, fmt.Errorf("invalid status code, expected 200 got %d", resp.StatusCode) return nil, fmt.Errorf("invalid status code, expected 202 got %d", resp.StatusCode)
} }
// first request had an invalid status code // first request had an invalid status code
return nil, fmt.Errorf("invalid status code, expected 200/403 got %d", resp.StatusCode) return nil, fmt.Errorf("invalid status code, expected 202/403 got %d", resp.StatusCode)
} }
// internalRequest sends a request to the acme challenge hosting api // internalRequest sends a request to the acme challenge hosting api

View File

@ -55,7 +55,7 @@ func (f *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("missing perm 'test:acme:clean'") return nil, fmt.Errorf("missing perm 'test:acme:clean'")
} }
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK) rec.WriteHeader(http.StatusAccepted)
f.req = req f.req = req
return rec.Result(), nil return rec.Result(), nil
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
"io" "io"
"log" "log"
"math/rand"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -390,13 +391,29 @@ func (s *Service) getDnsProvider(name, token string) (challenge.Provider, error)
} }
} }
// getPrivateKey reads the private key for the specified certificate id // getPrivateKey reads the private key for the specified certificate id, or
// generates one is the file doesn't exist
func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) { func (s *Service) getPrivateKey(id uint64) (*rsa.PrivateKey, error) {
pemBytes, err := os.ReadFile(filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id))) fPath := filepath.Join(s.keyDir, fmt.Sprintf("%d.key.pem", id))
pemBytes, err := os.ReadFile(fPath)
if err != nil { if err != nil {
if os.IsNotExist(err) {
key, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 4096)
if err != nil {
return nil, fmt.Errorf("generate rsa key error: %w", err)
}
err = os.WriteFile(fPath, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), os.ModePerm)
if err != nil {
return nil, fmt.Errorf("failed to save rsa key: %w", err)
}
return key, nil
}
return nil, err return nil, err
} }
keyBlock, _ := pem.Decode(pemBytes) keyBlock, _ := pem.Decode(pemBytes)
if keyBlock == nil {
return nil, fmt.Errorf("invalid pem block: failed to parse")
}
if keyBlock.Type != "RSA PRIVATE KEY" { if keyBlock.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("invalid pem block type") return nil, fmt.Errorf("invalid pem block type")
} }

View File

@ -136,7 +136,7 @@ func setupPebbleTest(t *testing.T, serverTls *certgen.CertGen) *Service {
privKey, err := rsa.GenerateKey(rand.Reader, 2048) privKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), x509.MarshalPKCS1PrivateKey(privKey), os.ModePerm)) assert.NoError(t, os.WriteFile(filepath.Join(keyDir, "1.key.pem"), pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}), os.ModePerm))
return service return service
} }