diff --git a/client.go b/client.go index 13e2313..06b2bbf 100644 --- a/client.go +++ b/client.go @@ -11,7 +11,7 @@ import ( "time" ) -func MakeClientTls(name pkix.Name, serialNumber *big.Int) (*CertGen, error) { +func MakeClientTls(ca *CertGen, name pkix.Name, serialNumber *big.Int) (*CertGen, error) { cert := &x509.Certificate{ SerialNumber: serialNumber, Subject: name, @@ -27,7 +27,10 @@ func MakeClientTls(name pkix.Name, serialNumber *big.Int) (*CertGen, error) { log.Fatalln("Failed to generate client private key:", err) } - clientBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, clientPrivKey.Public(), clientPrivKey) + if ca == nil { + ca = &CertGen{cert: cert, key: clientPrivKey} + } + clientBytes, err := x509.CreateCertificate(rand.Reader, cert, ca.cert, clientPrivKey.Public(), ca.key) if err != nil { log.Fatalln("Failed to generate client certificate bytes:", err) } diff --git a/server.go b/server.go index 9587a95..e5ad2f1 100644 --- a/server.go +++ b/server.go @@ -30,6 +30,9 @@ func MakeServerTls(ca *CertGen, name pkix.Name, serialNumber *big.Int, dnsNames log.Fatalln("Failed to generate server private key:", err) } + if ca == nil { + ca = &CertGen{cert: cert, key: serverPrivKey} + } serverBytes, err := x509.CreateCertificate(rand.Reader, cert, ca.cert, serverPrivKey.Public(), ca.key) if err != nil { log.Fatalln("Failed to generate server certificate bytes:", err)