Finish up tests (mjwt_test).

Fix MJwt func naming.
Move seperate errors.New to a global var
This commit is contained in:
Captain ALM 2024-06-09 19:31:12 +01:00
parent 407f8510b6
commit ce5eccfb3c
Signed by: alfred
GPG Key ID: 4E4ADD02609997B1
4 changed files with 139 additions and 53 deletions

View File

@ -9,6 +9,8 @@ import (
"time" "time"
) )
var mt_ExtraKID = "tester"
type testClaims struct{ TestValue string } type testClaims struct{ TestValue string }
func (t testClaims) Valid() error { func (t testClaims) Valid() error {
@ -31,11 +33,31 @@ func (t testClaims2) Valid() error {
func (t testClaims2) Type() string { return "testClaims2" } func (t testClaims2) Type() string { return "testClaims2" }
func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) {
ks = NewMJwtKeyStore()
var err error
a, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ks.SetKey("key1", a)
b, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ks.SetKey("key2", b)
c, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ks.SetKey("key3", c)
return
}
func TestExtractClaims(t *testing.T) { func TestExtractClaims(t *testing.T) {
t.Parallel() t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048) kStore, key, _, _ := setupTestKeyStoreMJWT(t)
assert.NoError(t, err)
t.Run("TestNoKID", func(t *testing.T) {
t.Parallel()
s := NewMJwtSigner("mjwt.test", key) s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err) assert.NoError(t, err)
@ -43,13 +65,30 @@ func TestExtractClaims(t *testing.T) {
m := NewMJwtVerifier(&key.PublicKey) m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims](m, token) _, _, err = ExtractClaims[testClaims](m, token)
assert.NoError(t, err) assert.NoError(t, err)
})
t.Run("TestKID", func(t *testing.T) {
t.Parallel()
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2")
assert.NoError(t, err)
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
_, _, err = ExtractClaims[testClaims](m, token1)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](m, token2)
assert.NoError(t, err)
})
} }
func TestExtractClaimsFail(t *testing.T) { func TestExtractClaimsFail(t *testing.T) {
t.Parallel() t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048) kStore, key, key2, _ := setupTestKeyStoreMJWT(t)
assert.NoError(t, err)
t.Run("TestInvalidClaims", func(t *testing.T) {
t.Parallel()
s := NewMJwtSigner("mjwt.test", key) s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err) assert.NoError(t, err)
@ -58,4 +97,47 @@ func TestExtractClaimsFail(t *testing.T) {
_, _, err = ExtractClaims[testClaims2](m, token) _, _, err = ExtractClaims[testClaims2](m, token)
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch) assert.ErrorIs(t, err, ErrClaimTypeMismatch)
})
t.Run("TestDefaultKeyNoKID", func(t *testing.T) {
t.Parallel()
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1")
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
})
t.Run("TestNoDefaultKey", func(t *testing.T) {
t.Parallel()
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
m := NewMJwtVerifierWithKeyStore(nil, kStore)
_, _, err = ExtractClaims[testClaims](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
})
t.Run("TestKIDNonExist", func(t *testing.T) {
t.Parallel()
kStore.SetKey(mt_ExtraKID, key2)
assert.Contains(t, kStore.ListKeys(), mt_ExtraKID)
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID)
assert.NoError(t, err)
kStore.RemoveKey(mt_ExtraKID)
assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID)
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
_, _, err = ExtractClaims[testClaims](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
})
} }

View File

@ -11,6 +11,9 @@ import (
"time" "time"
) )
var ErrNoPrivateKeyFound = errors.New("no private key found")
var ErrSignerNil = errors.New("signer nil")
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name // defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
// to generate MJWT tokens // to generate MJWT tokens
type defaultMJwtSigner struct { type defaultMJwtSigner struct {
@ -37,7 +40,7 @@ func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeySto
return &defaultMJwtSigner{ return &defaultMJwtSigner{
issuer: issuer, issuer: issuer,
key: key, key: key,
verify: NewMjwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier), verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier),
} }
} }
@ -101,7 +104,7 @@ func (d *defaultMJwtSigner) Issuer() string {
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key // GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
if d == nil { if d == nil {
return "", errors.New("signer nil") return "", ErrSignerNil
} }
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims)) return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
} }
@ -110,10 +113,10 @@ func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, du
// GenerateJwt but is available for signing custom structs; uses the default key // GenerateJwt but is available for signing custom structs; uses the default key
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
if d == nil { if d == nil {
return "", errors.New("signer nil") return "", ErrSignerNil
} }
if d.key == nil { if d.key == nil {
return "", errors.New("no private key found") return "", ErrNoPrivateKeyFound
} }
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key) return token.SignedString(d.key)
@ -122,7 +125,7 @@ func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID // GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) { func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
if d == nil { if d == nil {
return "", errors.New("signer nil") return "", ErrSignerNil
} }
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID) return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
} }
@ -131,11 +134,11 @@ func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStri
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID // GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) { func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
if d == nil { if d == nil {
return "", errors.New("signer nil") return "", ErrSignerNil
} }
pKey := d.verify.GetKeyStore().GetKey(kID) pKey := d.verify.GetKeyStore().GetKey(kID)
if pKey == nil { if pKey == nil {
return "", errors.New("no private key found") return "", ErrNoPrivateKeyFound
} }
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = kID token.Header["kid"] = kID
@ -145,7 +148,7 @@ func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (stri
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt() // VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil { if d == nil {
return nil, errors.New("signer nil") return nil, ErrSignerNil
} }
return d.verify.VerifyJwt(token, claims) return d.verify.VerifyJwt(token, claims)
} }

View File

@ -16,13 +16,11 @@ import (
const st_prvExt = "prv" const st_prvExt = "prv"
const st_pubExt = "pub" const st_pubExt = "pub"
func setupTestDirSigner(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) { func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err) assert.NoError(t, err)
var key3 *rsa.PrivateKey = nil var key3 *rsa.PrivateKey = nil
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048) key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err) assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1) err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1)
@ -39,7 +37,6 @@ func setupTestDirSigner(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, fu
assert.NoError(t, err) assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey) err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey)
assert.NoError(t, err) assert.NoError(t, err)
}
return tempDir, key3, func(t *testing.T) { return tempDir, key3, func(t *testing.T) {
err := os.RemoveAll(tempDir) err := os.RemoveAll(tempDir)
@ -118,7 +115,7 @@ func TestReadOrCreatePrivateKey(t *testing.T) {
func TestNewMJwtSignerFromDirectory(t *testing.T) { func TestNewMJwtSignerFromDirectory(t *testing.T) {
t.Parallel() t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirSigner(t, true) tempDir, prvKey3, cleaner := setupTestDirSigner(t)
defer cleaner(t) defer cleaner(t)
signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt) signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt)
@ -135,7 +132,7 @@ func TestNewMJwtSignerFromDirectory(t *testing.T) {
func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) { func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) {
t.Parallel() t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirSigner(t, true) tempDir, prvKey3, cleaner := setupTestDirSigner(t)
defer cleaner(t) defer cleaner(t)
signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt) signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt)

View File

@ -7,6 +7,10 @@ import (
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
) )
var ErrNoPublicKeyFound = errors.New("no public key found")
var ErrKIDInvalid = errors.New("kid invalid")
var ErrVerifierNil = errors.New("verifier nil")
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate // defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
// MJWT tokens // MJWT tokens
type defaultMJwtVerifier struct { type defaultMJwtVerifier struct {
@ -18,12 +22,12 @@ var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey // NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier { func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return NewMjwtVerifierWithKeyStore(key, NewMJwtKeyStore()) return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore())
} }
// NewMjwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key // NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key
// and a KeyStore for kID based keys // and a KeyStore for kID based keys
func NewMjwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier { func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier {
return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore} return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore}
} }
@ -62,13 +66,13 @@ func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string)
} }
} }
return NewMjwtVerifierWithKeyStore(pub, kStore), nil return NewMJwtVerifierWithKeyStore(pub, kStore), nil
} }
// VerifyJwt validates and parses MJWT tokens and returns the claims // VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
if d == nil { if d == nil {
return nil, errors.New("verifier nil") return nil, ErrVerifierNil
} }
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kIDI, exs := token.Header["kid"] kIDI, exs := token.Header["kid"]
@ -77,16 +81,16 @@ func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jw
if ok { if ok {
key := d.kStore.GetKeyPublic(kID) key := d.kStore.GetKeyPublic(kID)
if key == nil { if key == nil {
return nil, errors.New("no public key found") return nil, ErrNoPublicKeyFound
} else { } else {
return key, nil return key, nil
} }
} else { } else {
return nil, errors.New("kid invalid") return nil, ErrKIDInvalid
} }
} }
if d.pub == nil { if d.pub == nil {
return nil, errors.New("no public key found") return nil, ErrNoPublicKeyFound
} }
return d.pub, nil return d.pub, nil
}) })