2024-06-08 23:57:52 +01:00
|
|
|
package mjwt
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
|
|
|
"crypto/rsa"
|
|
|
|
"github.com/1f349/rsa-helper/rsaprivate"
|
|
|
|
"github.com/1f349/rsa-helper/rsapublic"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"os"
|
|
|
|
"path"
|
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
2024-06-09 00:49:27 +01:00
|
|
|
// File-name suffixes used by setupTestDir when writing fixture keys and passed
// to NewMJwtKeyStoreFromDirectory when loading them back.
const (
	prvExt = "prv" // private key files, e.g. key1.pem.prv
	pubExt = "pub" // public key files, e.g. key2.pem.pub
)
|
|
|
|
|
|
|
|
func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
|
2024-06-08 23:57:52 +01:00
|
|
|
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
2024-06-09 00:49:27 +01:00
|
|
|
if genKeys {
|
|
|
|
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+prvExt), key1)
|
|
|
|
assert.NoError(t, err)
|
2024-06-08 23:57:52 +01:00
|
|
|
|
2024-06-09 00:49:27 +01:00
|
|
|
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+prvExt), key2)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+pubExt), &key2.PublicKey)
|
|
|
|
assert.NoError(t, err)
|
2024-06-08 23:57:52 +01:00
|
|
|
|
2024-06-09 00:49:27 +01:00
|
|
|
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+pubExt), &key3.PublicKey)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
}
|
2024-06-08 23:57:52 +01:00
|
|
|
|
2024-06-09 00:49:27 +01:00
|
|
|
return tempDir, func(t *testing.T) {
|
|
|
|
err := os.RemoveAll(tempDir)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func commonSubTests(t *testing.T, kStore KeyStore) {
|
|
|
|
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
2024-06-08 23:57:52 +01:00
|
|
|
assert.NoError(t, err)
|
2024-06-09 00:49:27 +01:00
|
|
|
|
|
|
|
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
2024-06-08 23:57:52 +01:00
|
|
|
assert.NoError(t, err)
|
|
|
|
|
2024-06-09 00:49:27 +01:00
|
|
|
const extraKID1 = "key4"
|
|
|
|
const extraKID2 = "key5"
|
|
|
|
|
|
|
|
t.Run("TestSetKey", func(t *testing.T) {
|
|
|
|
b := kStore.SetKey(extraKID1, key4)
|
|
|
|
assert.True(t, b)
|
|
|
|
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
|
|
|
b := kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
|
|
|
|
assert.True(t, b)
|
|
|
|
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("TestGetKey", func(t *testing.T) {
|
|
|
|
oKey := kStore.GetKey(extraKID1)
|
|
|
|
assert.Same(t, key4, oKey)
|
|
|
|
pKey := kStore.GetKey(extraKID2)
|
|
|
|
assert.Nil(t, pKey)
|
|
|
|
aKey := kStore.GetKey("key1")
|
|
|
|
assert.NotNil(t, aKey)
|
|
|
|
bKey := kStore.GetKey("key2")
|
|
|
|
assert.NotNil(t, bKey)
|
|
|
|
cKey := kStore.GetKey("key3")
|
|
|
|
assert.Nil(t, cKey)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("TestGetKeyPublic", func(t *testing.T) {
|
|
|
|
oKey := kStore.GetKeyPublic(extraKID1)
|
|
|
|
assert.Same(t, &key4.PublicKey, oKey)
|
|
|
|
pKey := kStore.GetKeyPublic(extraKID2)
|
|
|
|
assert.Same(t, &key5.PublicKey, pKey)
|
|
|
|
aKey := kStore.GetKeyPublic("key1")
|
|
|
|
assert.NotNil(t, aKey)
|
|
|
|
bKey := kStore.GetKeyPublic("key2")
|
|
|
|
assert.NotNil(t, bKey)
|
|
|
|
cKey := kStore.GetKeyPublic("key3")
|
|
|
|
assert.NotNil(t, cKey)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("TestRemoveKey", func(t *testing.T) {
|
|
|
|
kStore.RemoveKey(extraKID1)
|
|
|
|
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
|
|
|
oKey1 := kStore.GetKey(extraKID1)
|
|
|
|
assert.Nil(t, oKey1)
|
|
|
|
oKey2 := kStore.GetKeyPublic(extraKID1)
|
|
|
|
assert.Nil(t, oKey2)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("TestClearKeys", func(t *testing.T) {
|
|
|
|
kStore.ClearKeys()
|
|
|
|
assert.Empty(t, kStore.ListKeys())
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
tempDir, cleaner := setupTestDir(t, true)
|
|
|
|
defer cleaner(t)
|
|
|
|
|
|
|
|
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
|
2024-06-08 23:57:52 +01:00
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
assert.Len(t, kStore.ListKeys(), 3)
|
|
|
|
kIDsToFind := []string{"key1", "key2", "key3"}
|
|
|
|
for _, k := range kIDsToFind {
|
|
|
|
assert.Contains(t, kStore.ListKeys(), k)
|
|
|
|
}
|
2024-06-09 00:49:27 +01:00
|
|
|
|
|
|
|
commonSubTests(t, kStore)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestExportKeyStore(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
tempDir, cleaner := setupTestDir(t, true)
|
|
|
|
defer cleaner(t)
|
|
|
|
tempDir2, cleaner2 := setupTestDir(t, false)
|
|
|
|
defer cleaner2(t)
|
|
|
|
|
|
|
|
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
const prvExt2 = "v"
|
|
|
|
const pubExt2 = "b"
|
|
|
|
|
|
|
|
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
kIDsToFind := kStore.ListKeys()
|
|
|
|
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
|
|
|
|
for _, k := range kIDsToFind {
|
|
|
|
assert.Contains(t, kStore2.ListKeys(), k)
|
|
|
|
}
|
|
|
|
|
|
|
|
commonSubTests(t, kStore2)
|
2024-06-08 23:57:52 +01:00
|
|
|
}
|