diff --git a/mjwt_test.go b/mjwt_test.go index 3a83fe8..1c059eb 100644 --- a/mjwt_test.go +++ b/mjwt_test.go @@ -9,6 +9,8 @@ import ( "time" ) +var mt_ExtraKID = "tester" + type testClaims struct{ TestValue string } func (t testClaims) Valid() error { @@ -31,31 +33,111 @@ func (t testClaims2) Valid() error { func (t testClaims2) Type() string { return "testClaims2" } +func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) { + ks = NewMJwtKeyStore() + var err error + + a, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + ks.SetKey("key1", a) + + b, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + ks.SetKey("key2", b) + + c, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + ks.SetKey("key3", c) + + return +} + func TestExtractClaims(t *testing.T) { t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) + kStore, key, _, _ := setupTestKeyStoreMJWT(t) - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) - assert.NoError(t, err) + t.Run("TestNoKID", func(t *testing.T) { + t.Parallel() + s := NewMJwtSigner("mjwt.test", key) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims](m, token) - assert.NoError(t, err) + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims](m, token) + assert.NoError(t, err) + }) + + t.Run("TestKID", func(t *testing.T) { + t.Parallel() + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2") + assert.NoError(t, err) + + m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore) + _, _, err = ExtractClaims[testClaims](m, token1) + assert.NoError(t, err) + _, _, err = ExtractClaims[testClaims](m, token2) + assert.NoError(t, err) + }) } func TestExtractClaimsFail(t *testing.T) { t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) + kStore, key, key2, _ := setupTestKeyStoreMJWT(t) - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) - assert.NoError(t, err) + t.Run("TestInvalidClaims", func(t *testing.T) { + t.Parallel() + s := NewMJwtSigner("mjwt.test", key) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims2](m, token) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrClaimTypeMismatch) + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims2](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrClaimTypeMismatch) + }) + + t.Run("TestDefaultKeyNoKID", func(t *testing.T) { + t.Parallel() + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1") + assert.NoError(t, err) + + m := NewMJwtVerifier(&key.PublicKey) + _, _, err = ExtractClaims[testClaims](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPublicKeyFound) + }) + + t.Run("TestNoDefaultKey", func(t *testing.T) { + t.Parallel() + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) + + m := NewMJwtVerifierWithKeyStore(nil, kStore) + _, _, err = ExtractClaims[testClaims](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPublicKeyFound) + }) + + t.Run("TestKIDNonExist", func(t *testing.T) { + t.Parallel() + kStore.SetKey(mt_ExtraKID, key2) + assert.Contains(t, kStore.ListKeys(), mt_ExtraKID) + + s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) + token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID) + assert.NoError(t, err) + + kStore.RemoveKey(mt_ExtraKID) + assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID) + + m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore) + _, _, err = ExtractClaims[testClaims](m, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrNoPublicKeyFound) + }) } diff --git a/signer.go b/signer.go index 3959619..2679dbd 100644 --- a/signer.go +++ b/signer.go @@ -11,6 +11,9 @@ import ( "time" ) +var ErrNoPrivateKeyFound = errors.New("no private key found") +var ErrSignerNil = errors.New("signer nil") + // defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name // to generate MJWT tokens type defaultMJwtSigner struct { @@ -37,7 +40,7 @@ func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeySto return &defaultMJwtSigner{ issuer: issuer, key: key, - verify: NewMjwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier), + verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier), } } @@ -101,7 +104,7 @@ func (d *defaultMJwtSigner) Issuer() string { // GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { if d == nil { - return "", errors.New("signer nil") + return "", ErrSignerNil } return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims)) } @@ -110,10 +113,10 @@ func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, du // GenerateJwt but is available for signing custom structs; uses the default key func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { if d == nil { - return "", errors.New("signer nil") + return "", ErrSignerNil } if d.key == nil { - return "", errors.New("no private key found") + return "", ErrNoPrivateKeyFound } token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) return token.SignedString(d.key) @@ -122,7 +125,7 @@ func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { // GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) { if d == nil { - return "", errors.New("signer nil") + return "", ErrSignerNil } return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID) } @@ -131,11 +134,11 @@ func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStri // GenerateJwt but is available for signing custom structs; this gets signed with the specified kID func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) { if d == nil { - return "", errors.New("signer nil") + return "", ErrSignerNil } pKey := d.verify.GetKeyStore().GetKey(kID) if pKey == nil { - return "", errors.New("no private key found") + return "", ErrNoPrivateKeyFound } token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) token.Header["kid"] = kID @@ -145,7 +148,7 @@ func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (stri // VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt() func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { if d == nil { - return nil, errors.New("signer nil") + return nil, ErrSignerNil } return d.verify.VerifyJwt(token, claims) } diff --git a/signer_test.go b/signer_test.go index 468b8e0..f04fbae 100644 --- a/signer_test.go +++ b/signer_test.go @@ -16,30 +16,27 @@ import ( const st_prvExt = "prv" const st_pubExt = "pub" -func setupTestDirSigner(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) { +func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) { tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") assert.NoError(t, err) var key3 *rsa.PrivateKey = nil + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1) + assert.NoError(t, err) - if genKeys { - key1, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1) - assert.NoError(t, err) + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey) + assert.NoError(t, err) - key2, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey) - assert.NoError(t, err) - - key3, err = rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey) - assert.NoError(t, err) - } + key3, err = rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey) + assert.NoError(t, err) return tempDir, key3, func(t *testing.T) { err := os.RemoveAll(tempDir) @@ -118,7 +115,7 @@ func TestReadOrCreatePrivateKey(t *testing.T) { func TestNewMJwtSignerFromDirectory(t *testing.T) { t.Parallel() - tempDir, prvKey3, cleaner := setupTestDirSigner(t, true) + tempDir, prvKey3, cleaner := setupTestDirSigner(t) defer cleaner(t) signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt) @@ -135,7 +132,7 @@ func TestNewMJwtSignerFromDirectory(t *testing.T) { func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) { t.Parallel() - tempDir, prvKey3, cleaner := setupTestDirSigner(t, true) + tempDir, prvKey3, cleaner := setupTestDirSigner(t) defer cleaner(t) signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt) diff --git a/verifier.go b/verifier.go index b8abe80..05b8a2b 100644 --- a/verifier.go +++ b/verifier.go @@ -7,6 +7,10 @@ import ( "github.com/golang-jwt/jwt/v4" ) +var ErrNoPublicKeyFound = errors.New("no public key found") +var ErrKIDInvalid = errors.New("kid invalid") +var ErrVerifierNil = errors.New("verifier nil") + // defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate // MJWT tokens type defaultMJwtVerifier struct { @@ -18,12 +22,12 @@ var _ Verifier = &defaultMJwtVerifier{} // NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey func NewMJwtVerifier(key *rsa.PublicKey) Verifier { - return NewMjwtVerifierWithKeyStore(key, NewMJwtKeyStore()) + return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore()) } -// NewMjwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key +// NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key // and a KeyStore for kID based keys -func NewMjwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier { +func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier { return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore} } @@ -62,13 +66,13 @@ func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) } } - return NewMjwtVerifierWithKeyStore(pub, kStore), nil + return NewMJwtVerifierWithKeyStore(pub, kStore), nil } // VerifyJwt validates and parses MJWT tokens and returns the claims func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { if d == nil { - return nil, errors.New("verifier nil") + return nil, ErrVerifierNil } withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { kIDI, exs := token.Header["kid"] @@ -77,16 +81,16 @@ func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jw if ok { key := d.kStore.GetKeyPublic(kID) if key == nil { - return nil, errors.New("no public key found") + return nil, ErrNoPublicKeyFound } else { return key, nil } } else { - return nil, errors.New("kid invalid") + return nil, ErrKIDInvalid } } if d.pub == nil { - return nil, errors.New("no public key found") + return nil, ErrNoPublicKeyFound } return d.pub, nil })