Fix key_store tests

Add key_store support to signer and verifier
This commit is contained in:
Captain ALM 2024-06-09 16:49:57 +01:00
parent 32cfa7a30d
commit 6fbc9e3c1f
Signed by: alfred
GPG Key ID: 4E4ADD02609997B1
5 changed files with 213 additions and 42 deletions

View File

@ -12,20 +12,25 @@ type Signer interface {
Verifier Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error) SignJwt(claims jwt.Claims) (string, error)
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
Issuer() string Issuer() string
PrivateKey() *rsa.PrivateKey PrivateKey() *rsa.PrivateKey
PrivateKeyOf(kID string) *rsa.PrivateKey
} }
// Verifier is used to verify the validity MJWT tokens and extract the claim values. // Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface { type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey PublicKey() *rsa.PublicKey
PublicKeyOf(kID string) *rsa.PublicKey
GetKeyStore() KeyStore
} }
// KeyStore is used for the kid header support in Signer and Verifier. // KeyStore is used for the kid header support in Signer and Verifier.
type KeyStore interface { type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey) bool SetKey(kID string, prvKey *rsa.PrivateKey)
SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool SetKeyPublic(kID string, pubKey *rsa.PublicKey)
RemoveKey(kID string) RemoveKey(kID string)
ListKeys() []string ListKeys() []string
GetKey(kID string) *rsa.PrivateKey GetKey(kID string) *rsa.PrivateKey

View File

@ -32,7 +32,7 @@ func NewMJwtKeyStore() KeyStore {
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private // NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first . // rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt string) (KeyStore, error) { func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore // Create empty KeyStore
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore) ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents // List directory contents
@ -78,7 +78,7 @@ func NewMJwtKeyStoreFromDirectory(directory string, keyPrvExt string, keyPubExt
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified // ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
// extensions for public and private keys // extensions for public and private keys
func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt string) error { func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
if ks == nil { if ks == nil {
return errors.New("ks is nil") return errors.New("ks is nil")
} }
@ -110,21 +110,21 @@ func ExportKeyStore(ks KeyStore, directory string, keyPrvExt string, keyPubExt s
} }
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore. // SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) bool { func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
if d == nil || prvKey == nil { if d == nil || prvKey == nil {
return false return
} }
d.rwLocker.Lock() d.rwLocker.Lock()
defer d.rwLocker.Unlock() defer d.rwLocker.Unlock()
d.store[kID] = prvKey d.store[kID] = prvKey
d.storePub[kID] = &prvKey.PublicKey d.storePub[kID] = &prvKey.PublicKey
return true return
} }
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore. // SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bool { func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
if d == nil || pubKey == nil { if d == nil || pubKey == nil {
return false return
} }
d.rwLocker.Lock() d.rwLocker.Lock()
defer d.rwLocker.Unlock() defer d.rwLocker.Unlock()
@ -133,7 +133,7 @@ func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) bo
d.store[kID] = nil d.store[kID] = nil
} }
d.storePub[kID] = pubKey d.storePub[kID] = pubKey
return true return
} }
// RemoveKey removes a specified kID from the KeyStore. // RemoveKey removes a specified kID from the KeyStore.

View File

@ -14,7 +14,7 @@ import (
const prvExt = "prv" const prvExt = "prv"
const pubExt = "pub" const pubExt = "pub"
func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) { func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err) assert.NoError(t, err)
@ -43,7 +43,7 @@ func setupTestDir(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
} }
} }
func commonSubTests(t *testing.T, kStore KeyStore) { func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048) key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err) assert.NoError(t, err)
@ -54,14 +54,12 @@ func commonSubTests(t *testing.T, kStore KeyStore) {
const extraKID2 = "key5" const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) { t.Run("TestSetKey", func(t *testing.T) {
b := kStore.SetKey(extraKID1, key4) kStore.SetKey(extraKID1, key4)
assert.True(t, b)
assert.Contains(t, kStore.ListKeys(), extraKID1) assert.Contains(t, kStore.ListKeys(), extraKID1)
}) })
t.Run("TestSetKeyPublic", func(t *testing.T) { t.Run("TestSetKeyPublic", func(t *testing.T) {
b := kStore.SetKeyPublic(extraKID2, &key5.PublicKey) kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.True(t, b)
assert.Contains(t, kStore.ListKeys(), extraKID2) assert.Contains(t, kStore.ListKeys(), extraKID2)
}) })
@ -109,7 +107,7 @@ func commonSubTests(t *testing.T, kStore KeyStore) {
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel() t.Parallel()
tempDir, cleaner := setupTestDir(t, true) tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t) defer cleaner(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt) kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
@ -121,15 +119,15 @@ func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
assert.Contains(t, kStore.ListKeys(), k) assert.Contains(t, kStore.ListKeys(), k)
} }
commonSubTests(t, kStore) commonSubTestsKeyStore(t, kStore)
} }
func TestExportKeyStore(t *testing.T) { func TestExportKeyStore(t *testing.T) {
t.Parallel() t.Parallel()
tempDir, cleaner := setupTestDir(t, true) tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t) defer cleaner(t)
tempDir2, cleaner2 := setupTestDir(t, false) tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
defer cleaner2(t) defer cleaner2(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt) kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, prvExt, pubExt)
@ -150,5 +148,5 @@ func TestExportKeyStore(t *testing.T) {
assert.Contains(t, kStore2.ListKeys(), k) assert.Contains(t, kStore2.ListKeys(), k)
} }
commonSubTests(t, kStore2) commonSubTestsKeyStore(t, kStore2)
} }

118
signer.go
View File

@ -3,6 +3,7 @@ package mjwt
import ( import (
"bytes" "bytes"
"crypto/rsa" "crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate" "github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"io" "io"
@ -23,10 +24,16 @@ var _ Verifier = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey // NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
}
// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
// for no kID and a KeyStore for kID based keys
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
return &defaultMJwtSigner{ return &defaultMJwtSigner{
issuer: issuer, issuer: issuer,
key: key, key: key,
verify: newMJwtVerifier(&key.PublicKey), verify: NewMjwtVerifierWithKeyStore(&key.PublicKey, kStore).(*defaultMJwtVerifier),
} }
} }
@ -44,38 +51,131 @@ func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits i
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a // NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file. // rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) { func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
}
// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PrivateKey
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
}
// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
var err error
// read key // read key
key, err := rsaprivate.Read(file) var prv *rsa.PrivateKey = nil
if file != "" {
prv, err = rsaprivate.Read(file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
// create signer using rsa.PrivateKey // read KeyStore
return NewMJwtSigner(issuer, key), nil var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
} }
// Issuer returns the name of the issuer // Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string { return d.issuer } func (d *defaultMJwtSigner) Issuer() string {
if d == nil {
return ""
}
return d.issuer
}
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims // GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims)) return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
} }
// SignJwt signs a jwt.Claims compatible struct, this is used internally by // SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs // GenerateJwt but is available for signing custom structs; uses the default key
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key) return token.SignedString(d.key)
} }
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
}
// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
if d == nil {
return "", errors.New("signer nil")
}
pKey := d.verify.GetKeyStore().GetKey(kID)
if pKey == nil {
return "", errors.New("no private key found")
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = kID
return token.SignedString(pKey)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt() // VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil {
return nil, errors.New("signer nil")
}
return d.verify.VerifyJwt(token, claims) return d.verify.VerifyJwt(token, claims)
} }
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { return d.key } func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { return d.verify.pub } if d == nil {
return nil
}
return d.key
}
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
if d == nil {
return nil
}
return d.verify.pub
}
func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
if d == nil {
return nil
}
return d.verify.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
if d == nil {
return nil
}
return d.verify.GetKeyStore()
}
func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
if d == nil {
return nil
}
return d.verify.kStore.GetKey(kID)
}
// readOrCreatePrivateKey returns the private key it the file already exists, // readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if // generates a new private key and saves it to the file, or returns an error if

View File

@ -2,6 +2,7 @@ package mjwt
import ( import (
"crypto/rsa" "crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsapublic" "github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
) )
@ -10,35 +11,83 @@ import (
// MJWT tokens // MJWT tokens
type defaultMJwtVerifier struct { type defaultMJwtVerifier struct {
pub *rsa.PublicKey pub *rsa.PublicKey
kStore KeyStore
} }
var _ Verifier = &defaultMJwtVerifier{} var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey // NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier { func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return newMJwtVerifier(key) return NewMjwtVerifierWithKeyStore(key, NewMJwtKeyStore())
} }
func newMJwtVerifier(key *rsa.PublicKey) *defaultMJwtVerifier { // NewMjwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key
return &defaultMJwtVerifier{pub: key} // and a KeyStore for kID based keys
func NewMjwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier {
return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore}
} }
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a // NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
// rsa.PublicKey file // rsa.PublicKey file
func NewMJwtVerifierFromFile(file string) (Verifier, error) { func NewMJwtVerifierFromFile(file string) (Verifier, error) {
return NewMJwtVerifierFromFileAndDirectory(file, "", "", "")
}
// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PublicKey
func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) {
return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt)
}
// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) {
var err error
// read key // read key
pub, err := rsapublic.Read(file) var pub *rsa.PublicKey = nil
if file != "" {
pub, err = rsapublic.Read(file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
// create verifier using rsa.PublicKey // read KeyStore
return NewMJwtVerifier(pub), nil var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMjwtVerifierWithKeyStore(pub, kStore), nil
} }
// VerifyJwt validates and parses MJWT tokens and returns the claims // VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil {
return nil, errors.New("verifier nil")
}
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kIDI, exs := token.Header["kid"]
if exs {
kID, ok := kIDI.(string)
if ok {
key := d.kStore.GetKeyPublic(kID)
if key == nil {
return nil, errors.New("no public key found")
} else {
return key, nil
}
} else {
return nil, errors.New("kid invalid")
}
}
if d.pub == nil {
return nil, errors.New("no public key found")
}
return d.pub, nil return d.pub, nil
}) })
if err != nil { if err != nil {
@ -47,4 +96,23 @@ func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jw
return withClaims, claims.Valid() return withClaims, claims.Valid()
} }
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { return d.pub } func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey {
if d == nil {
return nil
}
return d.pub
}
func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey {
if d == nil {
return nil
}
return d.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtVerifier) GetKeyStore() KeyStore {
if d == nil {
return nil
}
return d.kStore
}