diff --git a/auth/access-token_test.go b/auth/access-token_test.go index acf53aa..11b2523 100644 --- a/auth/access-token_test.go +++ b/auth/access-token_test.go @@ -31,3 +31,29 @@ func TestCreateAccessToken(t *testing.T) { assert.True(t, b.Claims.Perms.Has("mjwt:test2")) assert.False(t, b.Claims.Perms.Has("mjwt:test3")) } + +func TestCreateAccessTokenInvalid(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey("test", key) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) + + accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test") + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) +} diff --git a/auth/pair_test.go b/auth/pair_test.go index 5a7dd77..0b9d135 100644 --- a/auth/pair_test.go +++ b/auth/pair_test.go @@ -36,3 +36,34 @@ func TestCreateTokenPair(t *testing.T) { assert.Equal(t, "1", b2.Subject) assert.Equal(t, "test2", b2.ID) } + +func TestCreateTokenPairWithKID(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey("test", key) + + ps := claims.NewPermStorage() + ps.Set("mjwt:test") + ps.Set("mjwt:test2") + + s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) + + accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test") + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.True(t, b.Claims.Perms.Has("mjwt:test")) + assert.True(t, b.Claims.Perms.Has("mjwt:test2")) + assert.False(t, b.Claims.Perms.Has("mjwt:test3")) + + _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + assert.NoError(t, err) + assert.Equal(t, "1", b2.Subject) + assert.Equal(t, "test2", b2.ID) +} diff --git a/auth/refresh-token_test.go b/auth/refresh-token_test.go index 4765c35..bcb7521 100644 --- a/auth/refresh-token_test.go +++ b/auth/refresh-token_test.go @@ -24,3 +24,23 @@ func TestCreateRefreshToken(t *testing.T) { assert.Equal(t, "test", b.ID) assert.Equal(t, "test2", b.Claims.AccessTokenId) } + +func TestCreateRefreshTokenWithKID(t *testing.T) { + t.Parallel() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + kStore := mjwt.NewMJwtKeyStore() + kStore.SetKey("test", key) + + s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) + + refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test") + assert.NoError(t, err) + + _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + assert.NoError(t, err) + assert.Equal(t, "1", b.Subject) + assert.Equal(t, "test", b.ID) + assert.Equal(t, "test2", b.Claims.AccessTokenId) +}