diff --git a/key_store.go b/key_store.go index fa39ad2..b0a0b1a 100644 --- a/key_store.go +++ b/key_store.go @@ -2,6 +2,7 @@ package mjwt import ( "crypto/rsa" + "errors" "github.com/1f349/rsa-helper/rsaprivate" "github.com/1f349/rsa-helper/rsapublic" "os" @@ -20,8 +21,8 @@ type defaultMJwtKeyStore struct { var _ KeyStore = &defaultMJwtKeyStore{} -// NewMJwtKeyStore creates a new defaultMJwtKeyStore. -func NewMJwtKeyStore() KeyStore { +// newDefaultMJwtKeyStore creates a new defaultMJwtKeyStore. +func newDefaultMJwtKeyStore() *defaultMJwtKeyStore { return &defaultMJwtKeyStore{ rwLocker: new(sync.RWMutex), store: make(map[string]*rsa.PrivateKey), @@ -29,11 +30,16 @@ func NewMJwtKeyStore() KeyStore { } } +// NewMJwtKeyStore creates a new defaultMJwtKeyStore. +func NewMJwtKeyStore() KeyStore { + return newDefaultMJwtKeyStore() +} + // NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private // rsa keys; the kID is the filename of the key up to the first . func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) { // Create empty KeyStore - ks := NewMJwtKeyStore() + ks := newDefaultMJwtKeyStore() // List directory contents dirEntries, err := os.ReadDir(directory) if err != nil { @@ -46,22 +52,66 @@ func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt lastDotIdx := strings.LastIndex(entry.Name(), ".") if firstDotIdx > 0 && lastDotIdx+1 < len(entry.Name()) { if entry.Name()[lastDotIdx+1:] == keyPrvExt { + kID := entry.Name()[:firstDotIdx] // Load rsa private key with the file name as the kID (Up to the first .) - key, err := rsaprivate.Read(path.Join(directory, entry.Name())) - if err == nil { - ks.SetKey(entry.Name()[:firstDotIdx], key) + key, err2 := rsaprivate.Read(path.Join(directory, entry.Name())) + if err2 == nil { + ks.store[kID] = key + ks.storePub[kID] = &key.PublicKey + } else { + err = err2 } } else if entry.Name()[lastDotIdx+1:] == keyPubExt { + kID := entry.Name()[:firstDotIdx] // Load rsa public key with the file name as the kID (Up to the first .) - key, err := rsapublic.Read(path.Join(directory, entry.Name())) - if err == nil { - ks.SetKeyPublic(entry.Name()[:firstDotIdx], key) + key, err2 := rsapublic.Read(path.Join(directory, entry.Name())) + if err2 == nil { + _, exs := ks.store[kID] + if !exs { + ks.store[kID] = nil + } + ks.storePub[kID] = key + } else { + err = err2 } } } } } - return ks, nil + return ks, err +} + +// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified +// extensions for public and private keys +func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt string) error { + if ks == nil { + return errors.New("ks is nil") + } + + // Create directory + err := os.MkdirAll(directory, 0700) + if err != nil { + return err + } + + // Export all keys + for _, kID := range ks.ListKeys() { + kPrv := ks.GetKey(kID) + if kPrv != nil { + err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv) + if err2 != nil { + err = err2 + } + } + kPub := ks.GetKeyPublic(kID) + if kPub != nil { + err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub) + if err2 != nil { + err = err2 + } + } + } + return err } // SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore. @@ -83,7 +133,10 @@ func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bo } d.rwLocker.Lock() defer d.rwLocker.Unlock() - delete(d.store, kID) + _, exs := d.store[kID] + if !exs { + d.store[kID] = nil + } d.storePub[kID] = pubKey return true } @@ -111,6 +164,7 @@ func (d *defaultMJwtKeyStore) ListKeys() []string { i := 0 for k := range d.store { lKeys[i] = k + i++ } return lKeys } @@ -150,10 +204,6 @@ func (d *defaultMJwtKeyStore) ClearKeys() { } d.rwLocker.Lock() defer d.rwLocker.Unlock() - for k := range d.store { - delete(d.store, k) - } - for k := range d.storePub { - delete(d.storePub, k) - } + clear(d.store) + clear(d.storePub) } diff --git a/key_store_test.go b/key_store_test.go new file mode 100644 index 0000000..d79333e --- /dev/null +++ b/key_store_test.go @@ -0,0 +1,47 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" + "github.com/stretchr/testify/assert" + "os" + "path" + "testing" +) + +func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { + t.Parallel() + tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") + assert.NoError(t, err) + + const prvExt = "prv" + const pubExt = "pub" + + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+pubExt), &key2.PublicKey) + assert.NoError(t, err) + + key3, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey) + assert.NoError(t, err) + + kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, "prv", "pub") + assert.NoError(t, err) + + assert.Len(t, kStore.ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, kStore.ListKeys(), k) + } +}