diff --git a/signer.go b/signer.go index 636c3ff..99f6631 100644 --- a/signer.go +++ b/signer.go @@ -2,7 +2,10 @@ package mjwt import ( "crypto/rsa" + "crypto/x509" "github.com/golang-jwt/jwt/v4" + "io" + "os" "time" ) @@ -25,6 +28,31 @@ func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { } } +// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path +// of a rsa.PrivateKey file. If the file does not exist then the file is created +// and a new private key is generated. +func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) { + privateKey, err := readOrCreatePrivateKey(file, random, bits) + if err != nil { + return nil, err + } + return NewMJwtSigner(issuer, privateKey), nil +} + +// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a +// rsa.PrivateKey file. +func NewMJwtSignerFromFile(issuer, file string) (Signer, error) { + raw, err := os.ReadFile(file) + if err != nil { + return nil, err + } + key, err := x509.ParsePKCS1PrivateKey(raw) + if err != nil { + return nil, err + } + return NewMJwtSigner(issuer, key), nil +} + // Issuer returns the name of the issuer func (d *defaultMJwtSigner) Issuer() string { return d.issuer } @@ -44,3 +72,41 @@ func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { return d.verify.VerifyJwt(token, claims) } + +// readOrCreatePrivateKey returns the private key it the file already exists, +// generates a new private key and saves it to the file, or returns an error if +// reading or generating failed. +func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) { + // read the file or return nil + f, err := readOrEmptyFile(file) + if err != nil { + return nil, err + } + if f == nil { + // generate a new key + key, err := rsa.GenerateKey(random, bits) + if err != nil { + return nil, err + } + + // write the key to the file + err = os.WriteFile(file, x509.MarshalPKCS1PrivateKey(key), 0600) + return key, err + } else { + // try to parse the private key + return x509.ParsePKCS1PrivateKey(f) + } +} + +// readOrEmptyFile returns bytes and errors from os.ReadFile or (nil, nil) if the +// file does not exist. +func readOrEmptyFile(file string) ([]byte, error) { + raw, err := os.ReadFile(file) + if err == nil { + return raw, nil + } + if os.IsNotExist(err) { + return nil, nil + } + return nil, err +} diff --git a/signer_test.go b/signer_test.go new file mode 100644 index 0000000..8691da3 --- /dev/null +++ b/signer_test.go @@ -0,0 +1,62 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +func TestNewMJwtSigner(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + NewMJwtSigner("Test", key) +} + +func TestNewMJwtSignerFromFile(t *testing.T) { + tempKey, err := os.CreateTemp("", "key-test-*.pem") + assert.NoError(t, err) + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + _, err = tempKey.Write(x509.MarshalPKCS1PrivateKey(key)) + assert.NoError(t, err) + assert.NoError(t, tempKey.Close()) + signer, err := NewMJwtSignerFromFile("Test", tempKey.Name()) + assert.NoError(t, err) + assert.NoError(t, os.Remove(tempKey.Name())) + _, err = NewMJwtSignerFromFile("Test", tempKey.Name()) + assert.Error(t, err) + assert.True(t, os.IsNotExist(err)) + assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key)) +} + +func TestNewMJwtSignerFromFileOrCreate(t *testing.T) { + tempKey, err := os.CreateTemp("", "key-test-*.pem") + assert.NoError(t, err) + assert.NoError(t, tempKey.Close()) + assert.NoError(t, os.Remove(tempKey.Name())) + signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048) + assert.NoError(t, err) + signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048) + assert.NoError(t, err) + assert.True(t, signer.(*defaultMJwtSigner).key.Equal(signer2.(*defaultMJwtSigner).key)) +} + +func TestReadOrCreatePrivateKey(t *testing.T) { + tempKey, err := os.CreateTemp("", "key-test-*.pem") + assert.NoError(t, err) + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + _, err = tempKey.Write(x509.MarshalPKCS1PrivateKey(key)) + assert.NoError(t, err) + assert.NoError(t, tempKey.Close()) + key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048) + assert.NoError(t, err) + assert.True(t, key.Equal(key2)) + assert.NoError(t, os.Remove(tempKey.Name())) + key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048) + assert.NoError(t, err) + assert.NoError(t, key3.Validate()) +}