Fix key store issues + tests.

This commit is contained in:
Captain ALM 2024-06-09 00:49:27 +01:00
parent 3a7b3dd250
commit 32cfa7a30d
Signed by: alfred
GPG Key ID: 4E4ADD02609997B1
3 changed files with 131 additions and 29 deletions

View File

@ -26,7 +26,7 @@ type Verifier interface {
type KeyStore interface { type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey) bool SetKey(kID string, prvKey *rsa.PrivateKey) bool
SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool
RemoveKey(kID string) bool RemoveKey(kID string)
ListKeys() []string ListKeys() []string
GetKey(kID string) *rsa.PrivateKey GetKey(kID string) *rsa.PrivateKey
GetKeyPublic(kID string) *rsa.PublicKey GetKeyPublic(kID string) *rsa.PublicKey

View File

@ -21,8 +21,8 @@ type defaultMJwtKeyStore struct {
var _ KeyStore = &defaultMJwtKeyStore{} var _ KeyStore = &defaultMJwtKeyStore{}
// newDefaultMJwtKeyStore creates a new defaultMJwtKeyStore. // NewMJwtKeyStore creates a new defaultMJwtKeyStore.
func newDefaultMJwtKeyStore() *defaultMJwtKeyStore { func NewMJwtKeyStore() KeyStore {
return &defaultMJwtKeyStore{ return &defaultMJwtKeyStore{
rwLocker: new(sync.RWMutex), rwLocker: new(sync.RWMutex),
store: make(map[string]*rsa.PrivateKey), store: make(map[string]*rsa.PrivateKey),
@ -30,16 +30,11 @@ func newDefaultMJwtKeyStore() *defaultMJwtKeyStore {
} }
} }
// NewMJwtKeyStore creates a new defaultMJwtKeyStore.
func NewMJwtKeyStore() KeyStore {
return newDefaultMJwtKeyStore()
}
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private // NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first . // rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) { func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore // Create empty KeyStore
ks := newDefaultMJwtKeyStore() ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents // List directory contents
dirEntries, err := os.ReadDir(directory) dirEntries, err := os.ReadDir(directory)
if err != nil { if err != nil {
@ -142,15 +137,15 @@ func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bo
} }
// RemoveKey removes a specified kID from the KeyStore. // RemoveKey removes a specified kID from the KeyStore.
func (d *defaultMJwtKeyStore) RemoveKey(kID string) bool { func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
if d == nil { if d == nil {
return false return
} }
d.rwLocker.Lock() d.rwLocker.Lock()
defer d.rwLocker.Unlock() defer d.rwLocker.Unlock()
delete(d.store, kID) delete(d.store, kID)
delete(d.storePub, kID) delete(d.storePub, kID)
return true return
} }
// ListKeys lists the kIDs of all the keys in the KeyStore. // ListKeys lists the kIDs of all the keys in the KeyStore.

View File

@ -11,14 +11,14 @@ import (
"testing" "testing"
) )
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
const prvExt = "prv" const prvExt = "prv"
const pubExt = "pub" const pubExt = "pub"
func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048) key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err) assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1) err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1)
@ -35,8 +35,84 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey) err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey)
assert.NoError(t, err) assert.NoError(t, err)
}
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, "prv", "pub") return tempDir, func(t *testing.T) {
err := os.RemoveAll(tempDir)
assert.NoError(t, err)
}
}
func commonSubTests(t *testing.T, kStore KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
key5, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
const extraKID1 = "key4"
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
b := kStore.SetKey(extraKID1, key4)
assert.True(t, b)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
b := kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.True(t, b)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
t.Run("TestGetKey", func(t *testing.T) {
oKey := kStore.GetKey(extraKID1)
assert.Same(t, key4, oKey)
pKey := kStore.GetKey(extraKID2)
assert.Nil(t, pKey)
aKey := kStore.GetKey("key1")
assert.NotNil(t, aKey)
bKey := kStore.GetKey("key2")
assert.NotNil(t, bKey)
cKey := kStore.GetKey("key3")
assert.Nil(t, cKey)
})
t.Run("TestGetKeyPublic", func(t *testing.T) {
oKey := kStore.GetKeyPublic(extraKID1)
assert.Same(t, &key4.PublicKey, oKey)
pKey := kStore.GetKeyPublic(extraKID2)
assert.Same(t, &key5.PublicKey, pKey)
aKey := kStore.GetKeyPublic("key1")
assert.NotNil(t, aKey)
bKey := kStore.GetKeyPublic("key2")
assert.NotNil(t, bKey)
cKey := kStore.GetKeyPublic("key3")
assert.NotNil(t, cKey)
})
t.Run("TestRemoveKey", func(t *testing.T) {
kStore.RemoveKey(extraKID1)
assert.NotContains(t, kStore.ListKeys(), extraKID1)
oKey1 := kStore.GetKey(extraKID1)
assert.Nil(t, oKey1)
oKey2 := kStore.GetKeyPublic(extraKID1)
assert.Nil(t, oKey2)
})
t.Run("TestClearKeys", func(t *testing.T) {
kStore.ClearKeys()
assert.Empty(t, kStore.ListKeys())
})
}
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDir(t, true)
defer cleaner(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, kStore.ListKeys(), 3) assert.Len(t, kStore.ListKeys(), 3)
@ -44,4 +120,35 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
for _, k := range kIDsToFind { for _, k := range kIDsToFind {
assert.Contains(t, kStore.ListKeys(), k) assert.Contains(t, kStore.ListKeys(), k)
} }
commonSubTests(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDir(t, true)
defer cleaner(t)
tempDir2, cleaner2 := setupTestDir(t, false)
defer cleaner2(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
assert.NoError(t, err)
const prvExt2 = "v"
const pubExt2 = "b"
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
assert.NoError(t, err)
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
assert.NoError(t, err)
kIDsToFind := kStore.ListKeys()
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
for _, k := range kIDsToFind {
assert.Contains(t, kStore2.ListKeys(), k)
}
commonSubTests(t, kStore2)
} }