diff --git a/key_store_test.go b/key_store_test.go index c7ef163..264f689 100644 --- a/key_store_test.go +++ b/key_store_test.go @@ -11,8 +11,8 @@ import ( "testing" ) -const prvExt = "prv" -const pubExt = "pub" +const kst_prvExt = "prv" +const kst_pubExt = "pub" func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) { tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") @@ -21,19 +21,19 @@ func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T if genKeys { key1, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1) assert.NoError(t, err) key2, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+prvExt), key2) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2) assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key2.pem."+pubExt), &key2.PublicKey) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey) assert.NoError(t, err) key3, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey) assert.NoError(t, err) } @@ -110,7 +110,7 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { tempDir, cleaner := setupTestDirKeyStore(t, true) defer cleaner(t) - kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt) + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt) assert.NoError(t, err) assert.Len(t, kStore.ListKeys(), 3) @@ -130,7 +130,7 @@ func TestExportKeyStore(t *testing.T) { tempDir2, cleaner2 := setupTestDirKeyStore(t, false) defer cleaner2(t) - kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt) + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt) assert.NoError(t, err) const prvExt2 = "v" diff --git a/signer.go b/signer.go index f0cdc40..3959619 100644 --- a/signer.go +++ b/signer.go @@ -30,10 +30,14 @@ func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { // NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey // for no kID and a KeyStore for kID based keys func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer { + var pKey *rsa.PublicKey = nil + if key != nil { + pKey = &key.PublicKey + } return &defaultMJwtSigner{ issuer: issuer, key: key, - verify: NewMjwtVerifierWithKeyStore(&key.PublicKey, kStore).(*defaultMJwtVerifier), + verify: NewMjwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier), } } @@ -108,6 +112,9 @@ func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { if d == nil { return "", errors.New("signer nil") } + if d.key == nil { + return "", errors.New("no private key found") + } token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) return token.SignedString(d.key) } diff --git a/signer_test.go b/signer_test.go index 64d8fda..468b8e0 100644 --- a/signer_test.go +++ b/signer_test.go @@ -5,11 +5,48 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" "github.com/stretchr/testify/assert" "os" + "path" "testing" ) +const st_prvExt = "prv" +const st_pubExt = "pub" + +func setupTestDirSigner(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) { + tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") + assert.NoError(t, err) + + var key3 *rsa.PrivateKey = nil + + if genKeys { + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey) + assert.NoError(t, err) + } + + return tempDir, key3, func(t *testing.T) { + err := os.RemoveAll(tempDir) + assert.NoError(t, err) + } +} + func TestNewMJwtSigner(t *testing.T) { t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -17,6 +54,16 @@ func TestNewMJwtSigner(t *testing.T) { NewMJwtSigner("Test", key) } +func TestNewMJwtSignerWithKeyStore(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + kStore := NewMJwtKeyStore() + kStore.SetKey("test", key) + assert.Contains(t, kStore.ListKeys(), "test") + NewMJwtSignerWithKeyStore("Test", nil, kStore) +} + func TestNewMJwtSignerFromFile(t *testing.T) { t.Parallel() tempKey, err := os.CreateTemp("", "key-test-*.pem") @@ -67,3 +114,38 @@ func TestReadOrCreatePrivateKey(t *testing.T) { assert.NoError(t, err) assert.NoError(t, key3.Validate()) } + +func TestNewMJwtSignerFromDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirSigner(t, true) + defer cleaner(t) + + signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt) + assert.NoError(t, err) + + assert.Len(t, signer.GetKeyStore().ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, signer.GetKeyStore().ListKeys(), k) + } + assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3"))) +} + +func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirSigner(t, true) + defer cleaner(t) + + signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt) + assert.NoError(t, err) + + assert.Len(t, signer.GetKeyStore().ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, signer.GetKeyStore().ListKeys(), k) + } + assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3"))) + assert.True(t, signer.PrivateKey().Equal(signer.GetKeyStore().GetKey("key1"))) +} diff --git a/verifier_test.go b/verifier_test.go index 8e448dd..378ce1a 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -5,12 +5,49 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" "github.com/stretchr/testify/assert" "os" + "path" "testing" "time" ) +const vt_prvExt = "prv" +const vt_pubExt = "pub" + +func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) { + tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") + assert.NoError(t, err) + + var key3 *rsa.PrivateKey = nil + + if genKeys { + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey) + assert.NoError(t, err) + } + + return tempDir, key3, func(t *testing.T) { + err := os.RemoveAll(tempDir) + assert.NoError(t, err) + } +} + func TestNewMJwtVerifierFromFile(t *testing.T) { t.Parallel() key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -32,3 +69,43 @@ func TestNewMJwtVerifierFromFile(t *testing.T) { err = os.Remove(temp.Name()) assert.NoError(t, err) } + +func TestNewMJwtVerifierFromDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true) + defer cleaner(t) + + s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + s.GetKeyStore().SetKey("key3", prvKey3) + token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3") + assert.NoError(t, err) + + v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](v, token) + assert.NoError(t, err) +} + +func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) { + t.Parallel() + + tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true) + defer cleaner(t) + + s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + s.GetKeyStore().SetKey("key3", prvKey3) + token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}) + assert.NoError(t, err) + token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3") + assert.NoError(t, err) + + v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](v, token1) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](v, token2) + assert.NoError(t, err) +}