diff --git a/interfaces.go b/interfaces.go index ce63f1d..dd7add2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -26,7 +26,7 @@ type Verifier interface { type KeyStore interface { SetKey(kID string, prvKey *rsa.PrivateKey) bool SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool - RemoveKey(kID string) bool + RemoveKey(kID string) ListKeys() []string GetKey(kID string) *rsa.PrivateKey GetKeyPublic(kID string) *rsa.PublicKey diff --git a/key_store.go b/key_store.go index b0a0b1a..a0a5e03 100644 --- a/key_store.go +++ b/key_store.go @@ -21,8 +21,8 @@ type defaultMJwtKeyStore struct { var _ KeyStore = &defaultMJwtKeyStore{} -// newDefaultMJwtKeyStore creates a new defaultMJwtKeyStore. -func newDefaultMJwtKeyStore() *defaultMJwtKeyStore { +// NewMJwtKeyStore creates a new defaultMJwtKeyStore. +func NewMJwtKeyStore() KeyStore { return &defaultMJwtKeyStore{ rwLocker: new(sync.RWMutex), store: make(map[string]*rsa.PrivateKey), @@ -30,16 +30,11 @@ func newDefaultMJwtKeyStore() *defaultMJwtKeyStore { } } -// NewMJwtKeyStore creates a new defaultMJwtKeyStore. -func NewMJwtKeyStore() KeyStore { - return newDefaultMJwtKeyStore() -} - // NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private // rsa keys; the kID is the filename of the key up to the first . func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) { // Create empty KeyStore - ks := newDefaultMJwtKeyStore() + ks := NewMJwtKeyStore().(*defaultMJwtKeyStore) // List directory contents dirEntries, err := os.ReadDir(directory) if err != nil { @@ -142,15 +137,15 @@ func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bo } // RemoveKey removes a specified kID from the KeyStore. -func (d *defaultMJwtKeyStore) RemoveKey(kID string) bool { +func (d *defaultMJwtKeyStore) RemoveKey(kID string) { if d == nil { - return false + return } d.rwLocker.Lock() defer d.rwLocker.Unlock() delete(d.store, kID) delete(d.storePub, kID) - return true + return } // ListKeys lists the kIDs of all the keys in the KeyStore. diff --git a/key_store_test.go b/key_store_test.go index d79333e..69963b4 100644 --- a/key_store_test.go +++ b/key_store_test.go @@ -11,32 +11,108 @@ import ( "testing" ) -func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { - t.Parallel() +const prvExt = "prv" +const pubExt = "pub" + +func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) { tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") assert.NoError(t, err) - const prvExt = "prv" - const pubExt = "pub" + if genKeys { + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1) + assert.NoError(t, err) - key1, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1) + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey) + assert.NoError(t, err) + } + + return tempDir, func(t *testing.T) { + err := os.RemoveAll(tempDir) + assert.NoError(t, err) + } +} + +func commonSubTests(t *testing.T, kStore KeyStore) { + key4, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - key2, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+prvExt), key2) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key2.pem."+pubExt), &key2.PublicKey) + key5, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - key3, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey) - assert.NoError(t, err) + const extraKID1 = "key4" + const extraKID2 = "key5" - kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, "prv", "pub") + t.Run("TestSetKey", func(t *testing.T) { + b := kStore.SetKey(extraKID1, key4) + assert.True(t, b) + assert.Contains(t, kStore.ListKeys(), extraKID1) + }) + + t.Run("TestSetKeyPublic", func(t *testing.T) { + b := kStore.SetKeyPublic(extraKID2, &key5.PublicKey) + assert.True(t, b) + assert.Contains(t, kStore.ListKeys(), extraKID2) + }) + + t.Run("TestGetKey", func(t *testing.T) { + oKey := kStore.GetKey(extraKID1) + assert.Same(t, key4, oKey) + pKey := kStore.GetKey(extraKID2) + assert.Nil(t, pKey) + aKey := kStore.GetKey("key1") + assert.NotNil(t, aKey) + bKey := kStore.GetKey("key2") + assert.NotNil(t, bKey) + cKey := kStore.GetKey("key3") + assert.Nil(t, cKey) + }) + + t.Run("TestGetKeyPublic", func(t *testing.T) { + oKey := kStore.GetKeyPublic(extraKID1) + assert.Same(t, &key4.PublicKey, oKey) + pKey := kStore.GetKeyPublic(extraKID2) + assert.Same(t, &key5.PublicKey, pKey) + aKey := kStore.GetKeyPublic("key1") + assert.NotNil(t, aKey) + bKey := kStore.GetKeyPublic("key2") + assert.NotNil(t, bKey) + cKey := kStore.GetKeyPublic("key3") + assert.NotNil(t, cKey) + }) + + t.Run("TestRemoveKey", func(t *testing.T) { + kStore.RemoveKey(extraKID1) + assert.NotContains(t, kStore.ListKeys(), extraKID1) + oKey1 := kStore.GetKey(extraKID1) + assert.Nil(t, oKey1) + oKey2 := kStore.GetKeyPublic(extraKID1) + assert.Nil(t, oKey2) + }) + + t.Run("TestClearKeys", func(t *testing.T) { + kStore.ClearKeys() + assert.Empty(t, kStore.ListKeys()) + }) +} + +func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { + t.Parallel() + + tempDir, cleaner := setupTestDir(t, true) + defer cleaner(t) + + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt) assert.NoError(t, err) assert.Len(t, kStore.ListKeys(), 3) @@ -44,4 +120,35 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { for _, k := range kIDsToFind { assert.Contains(t, kStore.ListKeys(), k) } + + commonSubTests(t, kStore) +} + +func TestExportKeyStore(t *testing.T) { + t.Parallel() + + tempDir, cleaner := setupTestDir(t, true) + defer cleaner(t) + tempDir2, cleaner2 := setupTestDir(t, false) + defer cleaner2(t) + + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt) + assert.NoError(t, err) + + const prvExt2 = "v" + const pubExt2 = "b" + + err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2) + assert.NoError(t, err) + + kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2) + assert.NoError(t, err) + + kIDsToFind := kStore.ListKeys() + assert.Len(t, kStore2.ListKeys(), len(kIDsToFind)) + for _, k := range kIDsToFind { + assert.Contains(t, kStore2.ListKeys(), k) + } + + commonSubTests(t, kStore2) }