Fix key_store tests

Add key_store support to signer and verifier
This commit is contained in:
Captain ALM 2024-06-09 16:49:57 +01:00
parent 32cfa7a30d
commit 6fbc9e3c1f
Signed by: alfred
GPG Key ID: 4E4ADD02609997B1
5 changed files with 213 additions and 42 deletions

View File

@ -12,20 +12,25 @@ type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
PrivateKeyOf(kID string) *rsa.PrivateKey
}
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
PublicKeyOf(kID string) *rsa.PublicKey
GetKeyStore() KeyStore
}
// KeyStore is used for the kid header support in Signer and Verifier.
type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey) bool
SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool
SetKey(kID string, prvKey *rsa.PrivateKey)
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
RemoveKey(kID string)
ListKeys() []string
GetKey(kID string) *rsa.PrivateKey

View File

@ -32,7 +32,7 @@ func NewMJwtKeyStore() KeyStore {
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) {
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents
@ -78,7 +78,7 @@ func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
// extensions for public and private keys
func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt string) error {
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
if ks == nil {
return errors.New("ks is nil")
}
@ -110,21 +110,21 @@ func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt s
}
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) bool {
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
if d == nil || prvKey == nil {
return false
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
d.store[kID] = prvKey
d.storePub[kID] = &prvKey.PublicKey
return true
return
}
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool {
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
if d == nil || pubKey == nil {
return false
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
@ -133,7 +133,7 @@ func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bo
d.store[kID] = nil
}
d.storePub[kID] = pubKey
return true
return
}
// RemoveKey removes a specified kID from the KeyStore.

View File

@ -14,7 +14,7 @@ import (
const prvExt = "prv"
const pubExt = "pub"
func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
@ -43,7 +43,7 @@ func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
}
}
func commonSubTests(t *testing.T, kStore KeyStore) {
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
@ -54,14 +54,12 @@ func commonSubTests(t *testing.T, kStore KeyStore) {
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
b := kStore.SetKey(extraKID1, key4)
assert.True(t, b)
kStore.SetKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
b := kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.True(t, b)
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
@ -109,7 +107,7 @@ func commonSubTests(t *testing.T, kStore KeyStore) {
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDir(t, true)
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
@ -121,15 +119,15 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
assert.Contains(t, kStore.ListKeys(), k)
}
commonSubTests(t, kStore)
commonSubTestsKeyStore(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDir(t, true)
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
tempDir2, cleaner2 := setupTestDir(t, false)
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
defer cleaner2(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
@ -150,5 +148,5 @@ func TestExportKeyStore(t *testing.T) {
assert.Contains(t, kStore2.ListKeys(), k)
}
commonSubTests(t, kStore2)
commonSubTestsKeyStore(t, kStore2)
}

122
signer.go
View File

@ -3,6 +3,7 @@ package mjwt
import (
"bytes"
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"io"
@ -23,10 +24,16 @@ var _ Verifier = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
}
// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
// for no kID and a KeyStore for kID based keys
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
return &defaultMJwtSigner{
issuer: issuer,
key: key,
verify: newMJwtVerifier(&key.PublicKey),
verify: NewMjwtVerifierWithKeyStore(&key.PublicKey, kStore).(*defaultMJwtVerifier),
}
}
@ -44,38 +51,131 @@ func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits i
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
}
// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PrivateKey
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
}
// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
var err error
// read key
key, err := rsaprivate.Read(file)
if err != nil {
return nil, err
var prv *rsa.PrivateKey = nil
if file != "" {
prv, err = rsaprivate.Read(file)
if err != nil {
return nil, err
}
}
// create signer using rsa.PrivateKey
return NewMJwtSigner(issuer, key), nil
// read KeyStore
var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
}
// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string { return d.issuer }
func (d *defaultMJwtSigner) Issuer() string {
if d == nil {
return ""
}
return d.issuer
}
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
}
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs
// GenerateJwt but is available for signing custom structs; uses the default key
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
}
// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
pKey := d.verify.GetKeyStore().GetKey(kID)
if pKey == nil {
return "", errors.New("no private key found")
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = kID
return token.SignedString(pKey)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil {
return nil, errors.New("signer nil")
}
return d.verify.VerifyJwt(token, claims)
}
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key }
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub }
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
if d == nil {
return nil
}
return d.key
}
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
if d == nil {
return nil
}
return d.verify.pub
}
func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
if d == nil {
return nil
}
return d.verify.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
if d == nil {
return nil
}
return d.verify.GetKeyStore()
}
func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
if d == nil {
return nil
}
return d.verify.kStore.GetKey(kID)
}
// readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if

View File

@ -2,6 +2,7 @@ package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4"
)
@ -9,36 +10,84 @@ import (
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
// MJWT tokens
type defaultMJwtVerifier struct {
pub *rsa.PublicKey
pub *rsa.PublicKey
kStore KeyStore
}
var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return newMJwtVerifier(key)
return NewMjwtVerifierWithKeyStore(key, NewMJwtKeyStore())
}
func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier {
return &defaultMJwtVerifier{pub: key}
// NewMjwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key
// and a KeyStore for kID based keys
func NewMjwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier {
return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore}
}
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
// rsa.PublicKey file
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
return NewMJwtVerifierFromFileAndDirectory(file, "", "", "")
}
// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PublicKey
func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) {
return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt)
}
// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) {
var err error
// read key
pub, err := rsapublic.Read(file)
if err != nil {
return nil, err
var pub *rsa.PublicKey = nil
if file != "" {
pub, err = rsapublic.Read(file)
if err != nil {
return nil, err
}
}
// create verifier using rsa.PublicKey
return NewMJwtVerifier(pub), nil
// read KeyStore
var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMjwtVerifierWithKeyStore(pub, kStore), nil
}
// VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil {
return nil, errors.New("verifier nil")
}
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kIDI, exs := token.Header["kid"]
if exs {
kID, ok := kIDI.(string)
if ok {
key := d.kStore.GetKeyPublic(kID)
if key == nil {
return nil, errors.New("no public key found")
} else {
return key, nil
}
} else {
return nil, errors.New("kid invalid")
}
}
if d.pub == nil {
return nil, errors.New("no public key found")
}
return d.pub, nil
})
if err != nil {
@ -47,4 +96,23 @@ func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jw
return withClaims, claims.Valid()
}
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { return d.pub }
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey {
if d == nil {
return nil
}
return d.pub
}
func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey {
if d == nil {
return nil
}
return d.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtVerifier) GetKeyStore() KeyStore {
if d == nil {
return nil
}
return d.kStore
}