mirror of
https://github.com/1f349/mjwt.git
synced 2025-01-20 21:46:34 +00:00
Rewrite mjwt library to better support keystores
This commit is contained in:
parent
5d1bd6f8fd
commit
cd2d80cb09
@ -2,14 +2,13 @@ package auth
|
||||
|
||||
import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccessTokenClaims contains the JWT claims for an access token
|
||||
type AccessTokenClaims struct {
|
||||
Perms *claims.PermStorage `json:"per"`
|
||||
Perms *PermStorage `json:"per"`
|
||||
}
|
||||
|
||||
func (a AccessTokenClaims) Valid() error { return nil }
|
||||
@ -17,21 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
|
||||
func (a AccessTokenClaims) Type() string { return "access-token" }
|
||||
|
||||
// CreateAccessToken creates an access token with the default 15 minute duration
|
||||
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
||||
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
|
||||
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
|
||||
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
|
||||
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
|
||||
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
|
||||
}
|
||||
|
@ -1,55 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
}
|
||||
|
||||
func TestCreateAccessTokenInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
||||
|
||||
accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
|
25
auth/pair.go
25
auth/pair.go
@ -2,20 +2,19 @@ package auth
|
||||
|
||||
import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateTokenPair creates an access and refresh token pair using the default
|
||||
// 15 minute and 7 day durations respectively
|
||||
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
|
||||
}
|
||||
|
||||
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
||||
// custom durations for the access and refresh tokens
|
||||
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
|
||||
}
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// CreateTokenPairWithKID creates an access and refresh token pair using the default
|
||||
// 15 minute and 7 day durations respectively using the specified kID
|
||||
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
|
||||
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
|
||||
}
|
||||
|
||||
// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
|
||||
// custom durations for the access and refresh tokens
|
||||
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
|
||||
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
@ -1,29 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateTokenPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
@ -31,38 +28,7 @@ func TestCreateTokenPair(t *testing.T) {
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b2.Subject)
|
||||
assert.Equal(t, "test2", b2.ID)
|
||||
}
|
||||
|
||||
func TestCreateTokenPairWithKID(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
||||
|
||||
accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b2.Subject)
|
||||
assert.Equal(t, "test2", b2.ID)
|
||||
|
@ -1,4 +1,4 @@
|
||||
package claims
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
@ -1,4 +1,4 @@
|
||||
package claims
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
|
||||
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
||||
|
||||
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
||||
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
||||
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
|
||||
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
|
||||
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
|
||||
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
|
||||
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
@ -10,35 +8,15 @@ import (
|
||||
|
||||
func TestCreateRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
assert.Equal(t, "test2", b.Claims.AccessTokenId)
|
||||
}
|
||||
|
||||
func TestCreateRefreshTokenWithKID(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
|
||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
||||
|
||||
refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
|
@ -10,11 +10,11 @@ import (
|
||||
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
||||
|
||||
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
||||
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||
now := time.Now()
|
||||
return (&BaseTypeClaims[T]{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: p.Issuer(),
|
||||
Issuer: issuer,
|
||||
Subject: sub,
|
||||
Audience: aud,
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
||||
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
|
||||
|
||||
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
||||
// token and BaseTypeClaims
|
||||
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||
b := BaseTypeClaims[T]{
|
||||
RegisteredClaims: jwt.RegisteredClaims{},
|
||||
Claims: *new(T),
|
||||
}
|
||||
tok, err := p.VerifyJwt(token, &b)
|
||||
tok, err := ks.VerifyJwt(token, &b)
|
||||
return tok, b, err
|
||||
}
|
||||
|
100
claims_test.go
Normal file
100
claims_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testClaims struct{ TestValue string }
|
||||
|
||||
func (t testClaims) Valid() error {
|
||||
if t.TestValue != "hello" && t.TestValue != "world" {
|
||||
return fmt.Errorf("TestValue should be hello")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims) Type() string { return "testClaims" }
|
||||
|
||||
type testClaims2 struct{ TestValue2 string }
|
||||
|
||||
func (t testClaims2) Valid() error {
|
||||
if t.TestValue2 != "world" {
|
||||
return fmt.Errorf("TestValue2 should be world")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims2) Type() string { return "testClaims2" }
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
|
||||
t.Run("TestNoKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
a, _, err := ExtractClaims[testClaims](kStore, token)
|
||||
assert.NoError(t, err)
|
||||
kid, _ := a.Header["kid"].(string)
|
||||
assert.Equal(t, "key1", kid)
|
||||
})
|
||||
|
||||
t.Run("TestKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||
assert.NoError(t, err)
|
||||
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
k1, _, err := ExtractClaims[testClaims](kStore, token1)
|
||||
assert.NoError(t, err)
|
||||
k2, _, err := ExtractClaims[testClaims](kStore, token2)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClaimsFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
|
||||
t.Run("TestInvalidClaims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = ExtractClaims[testClaims2](kStore, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||
})
|
||||
|
||||
t.Run("TestKIDNonExist", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore.RemoveKey("key2")
|
||||
assert.NotContains(t, kStore.ListKeys(), "key2")
|
||||
|
||||
_, _, err = ExtractClaims[testClaims](kStore, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
})
|
||||
}
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/auth"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/subcommands"
|
||||
@ -35,7 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
|
||||
f.StringVar(&s.id, "id", "", "MJWT ID")
|
||||
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
|
||||
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
|
||||
f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key")
|
||||
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
|
||||
}
|
||||
|
||||
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||
@ -51,7 +50,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := auth.NewPermStorage()
|
||||
for i := 1; i < len(args); i++ {
|
||||
ps.Set(args[i])
|
||||
}
|
||||
@ -67,16 +66,16 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
}
|
||||
|
||||
var token string
|
||||
if s.kID == "\x00" {
|
||||
signer := mjwt.NewMJwtSigner(s.issuer, key)
|
||||
token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||
} else {
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey(s.kID, key)
|
||||
signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore)
|
||||
token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID)
|
||||
|
||||
kStore := mjwt.NewKeyStore()
|
||||
kStore.LoadPrivateKey(s.kID, key)
|
||||
|
||||
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, kStore)
|
||||
if err != nil {
|
||||
panic("this should not fail")
|
||||
}
|
||||
|
||||
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
|
||||
return subcommands.ExitFailure
|
||||
|
@ -1,7 +1,7 @@
|
||||
package claims
|
||||
package mjwt
|
||||
|
||||
// EmptyClaims contains no claims
|
||||
type EmptyClaims struct {}
|
||||
type EmptyClaims struct{}
|
||||
|
||||
func (e EmptyClaims) Valid() error { return nil }
|
||||
|
8
go.mod
8
go.mod
@ -10,11 +10,17 @@ require (
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/google/subcommands v1.2.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/spf13/afero v1.11.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/sync v0.7.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
)
|
||||
|
25
go.sum
25
go.sum
@ -2,19 +2,38 @@ github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9
|
||||
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
||||
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
|
||||
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -1,39 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Signer is used to generate MJWT tokens.
|
||||
// Signer can also be used as a Verifier.
|
||||
type Signer interface {
|
||||
Verifier
|
||||
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
|
||||
SignJwt(claims jwt.Claims) (string, error)
|
||||
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
|
||||
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
|
||||
Issuer() string
|
||||
PrivateKey() *rsa.PrivateKey
|
||||
PrivateKeyOf(kID string) *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
|
||||
type Verifier interface {
|
||||
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
|
||||
PublicKey() *rsa.PublicKey
|
||||
PublicKeyOf(kID string) *rsa.PublicKey
|
||||
GetKeyStore() KeyStore
|
||||
}
|
||||
|
||||
// KeyStore is used for the kid header support in Signer and Verifier.
|
||||
type KeyStore interface {
|
||||
SetKey(kID string, prvKey *rsa.PrivateKey)
|
||||
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
|
||||
RemoveKey(kID string)
|
||||
ListKeys() []string
|
||||
GetKey(kID string) *rsa.PrivateKey
|
||||
GetKeyPublic(kID string) *rsa.PublicKey
|
||||
ClearKeys()
|
||||
}
|
49
issuer.go
Normal file
49
issuer.go
Normal file
@ -0,0 +1,49 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Issuer struct {
|
||||
issuer string
|
||||
kid string
|
||||
keystore *KeyStore
|
||||
}
|
||||
|
||||
func NewIssuer(name, kid string) (*Issuer, error) {
|
||||
return NewIssuerWithKeyStore(name, kid, NewKeyStore())
|
||||
}
|
||||
|
||||
func NewIssuerWithKeyStore(name, kid string, keystore *KeyStore) (*Issuer, error) {
|
||||
i := &Issuer{name, kid, keystore}
|
||||
if i.keystore.HasPrivateKey(kid) {
|
||||
return i, nil
|
||||
}
|
||||
key, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.keystore.LoadPrivateKey(kid, key)
|
||||
return i, i.keystore.SaveSingleKey(kid)
|
||||
}
|
||||
|
||||
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
|
||||
}
|
||||
|
||||
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||
key, err := i.PrivateKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||
token.Header["kid"] = i.kid
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
|
||||
return i.keystore.GetPrivateKey(i.kid)
|
||||
}
|
58
issuer_test.go
Normal file
58
issuer_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIssuer(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("generate missing key for issuer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
})
|
||||
t.Run("use existing issuer key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
kStore.LoadPrivateKey("test", key)
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
privateKey, err := issuer.PrivateKey()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(privateKey))
|
||||
})
|
||||
t.Run("generate missing key in filesystem", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := afero.NewMemMapFs()
|
||||
kStore := NewKeyStoreWithDir(dir)
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
privKeyFile, err := dir.Open("test.private.pem")
|
||||
assert.NoError(t, err)
|
||||
privKey, err := rsaprivate.Decode(privKeyFile)
|
||||
assert.NoError(t, err)
|
||||
key, err := issuer.PrivateKey()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(privKey))
|
||||
})
|
||||
}
|
185
key_store.go
185
key_store.go
@ -1,185 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey
|
||||
// or with rsa.PrivateKey instances as well.
|
||||
type defaultMJwtKeyStore struct {
|
||||
rwLocker *sync.RWMutex
|
||||
store map[string]*rsa.PrivateKey
|
||||
storePub map[string]*rsa.PublicKey
|
||||
}
|
||||
|
||||
var _ KeyStore = &defaultMJwtKeyStore{}
|
||||
|
||||
// NewMJwtKeyStore creates a new defaultMJwtKeyStore.
|
||||
func NewMJwtKeyStore() KeyStore {
|
||||
return &defaultMJwtKeyStore{
|
||||
rwLocker: new(sync.RWMutex),
|
||||
store: make(map[string]*rsa.PrivateKey),
|
||||
storePub: make(map[string]*rsa.PublicKey),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
|
||||
// rsa keys; the kID is the filename of the key up to the first .
|
||||
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
|
||||
// Create empty KeyStore
|
||||
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
|
||||
// List directory contents
|
||||
dirEntries, err := os.ReadDir(directory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errs := make([]error, 0, len(dirEntries)/2)
|
||||
// Import keys from files, based on extension
|
||||
for _, entry := range dirEntries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
kID, _, _ := strings.Cut(entry.Name(), ".")
|
||||
if kID == "" {
|
||||
continue
|
||||
}
|
||||
pExt := path.Ext(entry.Name())
|
||||
if pExt == "."+keyPrvExt {
|
||||
// Load rsa private key with the file name as the kID (Up to the first .)
|
||||
key, err2 := rsaprivate.Read(path.Join(directory, entry.Name()))
|
||||
if err2 == nil {
|
||||
ks.store[kID] = key
|
||||
ks.storePub[kID] = &key.PublicKey
|
||||
}
|
||||
errs = append(errs, err2)
|
||||
} else if pExt == "."+keyPubExt {
|
||||
// Load rsa public key with the file name as the kID (Up to the first .)
|
||||
key, err2 := rsapublic.Read(path.Join(directory, entry.Name()))
|
||||
if err2 == nil {
|
||||
_, exs := ks.store[kID]
|
||||
if !exs {
|
||||
ks.store[kID] = nil
|
||||
}
|
||||
ks.storePub[kID] = key
|
||||
}
|
||||
errs = append(errs, err2)
|
||||
}
|
||||
}
|
||||
return ks, errors.Join(errs...)
|
||||
}
|
||||
|
||||
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
|
||||
// extensions for public and private keys
|
||||
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
|
||||
if ks == nil {
|
||||
return errors.New("ks is nil")
|
||||
}
|
||||
|
||||
// Create directory
|
||||
err := os.MkdirAll(directory, 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errs := make([]error, 0, len(ks.ListKeys())/2)
|
||||
// Export all keys
|
||||
for _, kID := range ks.ListKeys() {
|
||||
kPrv := ks.GetKey(kID)
|
||||
if kPrv != nil {
|
||||
err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv)
|
||||
errs = append(errs, err2)
|
||||
}
|
||||
kPub := ks.GetKeyPublic(kID)
|
||||
if kPub != nil {
|
||||
err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub)
|
||||
errs = append(errs, err2)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
|
||||
if prvKey == nil {
|
||||
return
|
||||
}
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
d.store[kID] = prvKey
|
||||
d.storePub[kID] = &prvKey.PublicKey
|
||||
return
|
||||
}
|
||||
|
||||
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
|
||||
if pubKey == nil {
|
||||
return
|
||||
}
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
_, exs := d.store[kID]
|
||||
if !exs {
|
||||
d.store[kID] = nil
|
||||
}
|
||||
d.storePub[kID] = pubKey
|
||||
return
|
||||
}
|
||||
|
||||
// RemoveKey removes a specified kID from the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
delete(d.store, kID)
|
||||
delete(d.storePub, kID)
|
||||
return
|
||||
}
|
||||
|
||||
// ListKeys lists the kIDs of all the keys in the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) ListKeys() []string {
|
||||
d.rwLocker.RLock()
|
||||
defer d.rwLocker.RUnlock()
|
||||
lKeys := make([]string, len(d.store))
|
||||
i := 0
|
||||
for k := range d.store {
|
||||
lKeys[i] = k
|
||||
i++
|
||||
}
|
||||
return lKeys
|
||||
}
|
||||
|
||||
// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found.
|
||||
func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey {
|
||||
d.rwLocker.RLock()
|
||||
defer d.rwLocker.RUnlock()
|
||||
kPrv, ok := d.store[kID]
|
||||
if ok {
|
||||
return kPrv
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found.
|
||||
func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey {
|
||||
d.rwLocker.RLock()
|
||||
defer d.rwLocker.RUnlock()
|
||||
kPub, ok := d.storePub[kID]
|
||||
if ok {
|
||||
return kPub
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearKeys removes all the stored keys in the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) ClearKeys() {
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
clear(d.store)
|
||||
clear(d.storePub)
|
||||
}
|
@ -1,152 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const kst_prvExt = "prv"
|
||||
const kst_pubExt = "pub"
|
||||
|
||||
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
|
||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
||||
assert.NoError(t, err)
|
||||
|
||||
if genKeys {
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return tempDir, func(t *testing.T) {
|
||||
err := os.RemoveAll(tempDir)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
|
||||
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
const extraKID1 = "key4"
|
||||
const extraKID2 = "key5"
|
||||
|
||||
t.Run("TestSetKey", func(t *testing.T) {
|
||||
kStore.SetKey(extraKID1, key4)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
||||
})
|
||||
|
||||
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
||||
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
||||
})
|
||||
|
||||
t.Run("TestGetKey", func(t *testing.T) {
|
||||
oKey := kStore.GetKey(extraKID1)
|
||||
assert.Same(t, key4, oKey)
|
||||
pKey := kStore.GetKey(extraKID2)
|
||||
assert.Nil(t, pKey)
|
||||
aKey := kStore.GetKey("key1")
|
||||
assert.NotNil(t, aKey)
|
||||
bKey := kStore.GetKey("key2")
|
||||
assert.NotNil(t, bKey)
|
||||
cKey := kStore.GetKey("key3")
|
||||
assert.Nil(t, cKey)
|
||||
})
|
||||
|
||||
t.Run("TestGetKeyPublic", func(t *testing.T) {
|
||||
oKey := kStore.GetKeyPublic(extraKID1)
|
||||
assert.Same(t, &key4.PublicKey, oKey)
|
||||
pKey := kStore.GetKeyPublic(extraKID2)
|
||||
assert.Same(t, &key5.PublicKey, pKey)
|
||||
aKey := kStore.GetKeyPublic("key1")
|
||||
assert.NotNil(t, aKey)
|
||||
bKey := kStore.GetKeyPublic("key2")
|
||||
assert.NotNil(t, bKey)
|
||||
cKey := kStore.GetKeyPublic("key3")
|
||||
assert.NotNil(t, cKey)
|
||||
})
|
||||
|
||||
t.Run("TestRemoveKey", func(t *testing.T) {
|
||||
kStore.RemoveKey(extraKID1)
|
||||
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
||||
oKey1 := kStore.GetKey(extraKID1)
|
||||
assert.Nil(t, oKey1)
|
||||
oKey2 := kStore.GetKeyPublic(extraKID1)
|
||||
assert.Nil(t, oKey2)
|
||||
})
|
||||
|
||||
t.Run("TestClearKeys", func(t *testing.T) {
|
||||
kStore.ClearKeys()
|
||||
assert.Empty(t, kStore.ListKeys())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, cleaner := setupTestDirKeyStore(t, true)
|
||||
defer cleaner(t)
|
||||
|
||||
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, kStore.ListKeys(), 3)
|
||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, kStore.ListKeys(), k)
|
||||
}
|
||||
|
||||
commonSubTestsKeyStore(t, kStore)
|
||||
}
|
||||
|
||||
func TestExportKeyStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, cleaner := setupTestDirKeyStore(t, true)
|
||||
defer cleaner(t)
|
||||
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
|
||||
defer cleaner2(t)
|
||||
|
||||
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
const prvExt2 = "v"
|
||||
const pubExt2 = "b"
|
||||
|
||||
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kIDsToFind := kStore.ListKeys()
|
||||
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, kStore2.ListKeys(), k)
|
||||
}
|
||||
|
||||
commonSubTestsKeyStore(t, kStore2)
|
||||
}
|
233
keystore.go
Normal file
233
keystore.go
Normal file
@ -0,0 +1,233 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrMissingPrivateKey = errors.New("missing private key")
|
||||
var ErrMissingPublicKey = errors.New("missing public key")
|
||||
var ErrMissingKeyPair = errors.New("missing key pair")
|
||||
|
||||
const PrivateStr = ".private"
|
||||
const PublicStr = ".public"
|
||||
|
||||
const PemExt = ".pem"
|
||||
const PrivatePemExt = PrivateStr + PemExt
|
||||
const PublicPemExt = PublicStr + PemExt
|
||||
|
||||
type KeyStore struct {
|
||||
mu *sync.RWMutex
|
||||
store map[string]*keyPair
|
||||
dir afero.Fs
|
||||
}
|
||||
|
||||
func NewKeyStore() *KeyStore {
|
||||
return &KeyStore{
|
||||
mu: new(sync.RWMutex),
|
||||
store: make(map[string]*keyPair),
|
||||
}
|
||||
}
|
||||
|
||||
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
|
||||
keyStore := NewKeyStore()
|
||||
keyStore.dir = dir
|
||||
return keyStore
|
||||
}
|
||||
|
||||
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
|
||||
keyStore := NewKeyStoreWithDir(dir)
|
||||
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
|
||||
// maybe this is "name.private.pem"
|
||||
name := filepath.Base(path)
|
||||
ext := filepath.Ext(name)
|
||||
if ext != PemExt {
|
||||
return nil
|
||||
}
|
||||
|
||||
name = strings.TrimSuffix(name, ext)
|
||||
ext = filepath.Ext(name)
|
||||
name = strings.TrimSuffix(name, ext)
|
||||
switch ext {
|
||||
case PrivateStr:
|
||||
open, err := dir.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decode, err := rsaprivate.Decode(open)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyStore.LoadPrivateKey(name, decode)
|
||||
return nil
|
||||
case PublicStr:
|
||||
open, err := dir.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decode, err := rsapublic.Decode(open)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyStore.LoadPublicKey(name, decode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// still invalid
|
||||
return nil
|
||||
})
|
||||
return keyStore, err
|
||||
}
|
||||
|
||||
type keyPair struct {
|
||||
private *rsa.PrivateKey
|
||||
public *rsa.PublicKey
|
||||
}
|
||||
|
||||
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
|
||||
k.mu.Lock()
|
||||
if k.store[kid] == nil {
|
||||
k.store[kid] = &keyPair{}
|
||||
}
|
||||
k.store[kid].private = key
|
||||
k.store[kid].public = &key.PublicKey
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
|
||||
k.mu.Lock()
|
||||
if k.store[kid] == nil {
|
||||
k.store[kid] = &keyPair{}
|
||||
}
|
||||
k.store[kid].public = key
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) RemoveKey(kid string) {
|
||||
k.mu.Lock()
|
||||
delete(k.store, kid)
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) ListKeys() []string {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
keys := make([]string, 0, len(k.store))
|
||||
for k, _ := range k.store {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if !k.internalHasPrivateKey(kid) {
|
||||
return nil, ErrMissingPrivateKey
|
||||
}
|
||||
return k.store[kid].private, nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if !k.internalHasPublicKey(kid) {
|
||||
return nil, ErrMissingPublicKey
|
||||
}
|
||||
return k.store[kid].public, nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) ClearKeys() {
|
||||
k.mu.Lock()
|
||||
clear(k.store)
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) HasPrivateKey(kid string) bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.internalHasPrivateKey(kid)
|
||||
}
|
||||
|
||||
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
|
||||
v := k.store[kid]
|
||||
return v != nil && v.private != nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) HasPublicKey(kid string) bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.internalHasPublicKey(kid)
|
||||
}
|
||||
|
||||
func (k *KeyStore) internalHasPublicKey(kid string) bool {
|
||||
v := k.store[kid]
|
||||
return v != nil && v.public != nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, ErrMissingPublicKey
|
||||
}
|
||||
return k.GetPublicKey(kid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withClaims, claims.Valid()
|
||||
}
|
||||
|
||||
func (k *KeyStore) SaveSingleKey(kid string) error {
|
||||
if k.dir == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
k.mu.RLock()
|
||||
pair := k.store[kid]
|
||||
k.mu.RUnlock()
|
||||
if pair == nil {
|
||||
return ErrMissingKeyPair
|
||||
}
|
||||
|
||||
var errs []error
|
||||
if pair.private != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||
}
|
||||
if pair.public != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (k *KeyStore) SaveKeys() error {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
workers := new(errgroup.Group)
|
||||
workers.SetLimit(runtime.NumCPU())
|
||||
for kid, pair := range k.store {
|
||||
workers.Go(func() error {
|
||||
var errs []error
|
||||
if pair.private != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||
}
|
||||
if pair.public != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
})
|
||||
}
|
||||
return workers.Wait()
|
||||
}
|
175
keystore_test.go
Normal file
175
keystore_test.go
Normal file
@ -0,0 +1,175 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const kst_prvExt = "prv"
|
||||
const kst_pubExt = "pub"
|
||||
|
||||
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
|
||||
tempDir := afero.NewMemMapFs()
|
||||
|
||||
if genKeys {
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return tempDir
|
||||
}
|
||||
|
||||
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
|
||||
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
const extraKID1 = "key4"
|
||||
const extraKID2 = "key5"
|
||||
|
||||
t.Run("TestSetKey", func(t *testing.T) {
|
||||
kStore.LoadPrivateKey(extraKID1, key4)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
||||
})
|
||||
|
||||
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
||||
kStore.LoadPublicKey(extraKID2, &key5.PublicKey)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
||||
})
|
||||
|
||||
t.Run("TestGetPrivateKey", func(t *testing.T) {
|
||||
oKey, err := kStore.GetPrivateKey(extraKID1)
|
||||
assert.NoError(t, err)
|
||||
assert.Same(t, key4, oKey)
|
||||
pKey, err := kStore.GetPrivateKey(extraKID2)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, pKey)
|
||||
aKey, err := kStore.GetPrivateKey("key1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, aKey)
|
||||
bKey, err := kStore.GetPrivateKey("key2")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bKey)
|
||||
cKey, err := kStore.GetPrivateKey("key3")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, cKey)
|
||||
wKey, err := kStore.GetPrivateKey("key1337")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, wKey)
|
||||
})
|
||||
|
||||
t.Run("TestGetPublicKey", func(t *testing.T) {
|
||||
oKey, err := kStore.GetPublicKey(extraKID1)
|
||||
assert.NoError(t, err)
|
||||
assert.Same(t, &key4.PublicKey, oKey)
|
||||
pKey, err := kStore.GetPublicKey(extraKID2)
|
||||
assert.NoError(t, err)
|
||||
assert.Same(t, &key5.PublicKey, pKey)
|
||||
aKey, err := kStore.GetPublicKey("key1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, aKey)
|
||||
bKey, err := kStore.GetPublicKey("key2")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bKey)
|
||||
cKey, err := kStore.GetPublicKey("key3")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cKey)
|
||||
wKey, err := kStore.GetPublicKey("key1337")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
assert.Nil(t, wKey)
|
||||
})
|
||||
|
||||
t.Run("TestRemoveKey", func(t *testing.T) {
|
||||
kStore.RemoveKey(extraKID1)
|
||||
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
||||
oKey1, err := kStore.GetPrivateKey(extraKID1)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||
assert.Nil(t, oKey1)
|
||||
oKey2, err := kStore.GetPublicKey(extraKID1)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
assert.Nil(t, oKey2)
|
||||
})
|
||||
|
||||
t.Run("TestClearKeys", func(t *testing.T) {
|
||||
kStore.ClearKeys()
|
||||
assert.Empty(t, kStore.ListKeys())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := setupTestDirKeyStore(t, true)
|
||||
|
||||
kStore, err := NewKeyStoreFromDir(tempDir)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, kStore.ListKeys(), 3)
|
||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, kStore.ListKeys(), k)
|
||||
}
|
||||
assert.True(t, kStore.HasPrivateKey("key1"))
|
||||
assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key
|
||||
assert.True(t, kStore.HasPrivateKey("key2"))
|
||||
assert.True(t, kStore.HasPublicKey("key2"))
|
||||
assert.False(t, kStore.HasPrivateKey("key3"))
|
||||
assert.True(t, kStore.HasPublicKey("key3"))
|
||||
|
||||
commonSubTestsKeyStore(t, kStore)
|
||||
}
|
||||
|
||||
func TestExportKeyStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := setupTestDirKeyStore(t, true)
|
||||
tempDir2 := setupTestDirKeyStore(t, false)
|
||||
|
||||
kStore, err := NewKeyStoreFromDir(tempDir)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// internally swap directory
|
||||
kStore.dir = tempDir2
|
||||
|
||||
err = kStore.SaveKeys()
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore2, err := NewKeyStoreFromDir(tempDir2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kidList1 := kStore.ListKeys()
|
||||
kidList2 := kStore2.ListKeys()
|
||||
sort.Strings(kidList1)
|
||||
sort.Strings(kidList2)
|
||||
assert.Equal(t, kidList1, kidList2)
|
||||
|
||||
commonSubTestsKeyStore(t, kStore2)
|
||||
}
|
143
mjwt_test.go
143
mjwt_test.go
@ -1,143 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var mt_ExtraKID = "tester"
|
||||
|
||||
type testClaims struct{ TestValue string }
|
||||
|
||||
func (t testClaims) Valid() error {
|
||||
if t.TestValue != "hello" && t.TestValue != "world" {
|
||||
return fmt.Errorf("TestValue should be hello")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims) Type() string { return "testClaims" }
|
||||
|
||||
type testClaims2 struct{ TestValue2 string }
|
||||
|
||||
func (t testClaims2) Valid() error {
|
||||
if t.TestValue2 != "world" {
|
||||
return fmt.Errorf("TestValue2 should be world")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims2) Type() string { return "testClaims2" }
|
||||
|
||||
func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) {
|
||||
ks = NewMJwtKeyStore()
|
||||
var err error
|
||||
|
||||
a, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
ks.SetKey("key1", a)
|
||||
|
||||
b, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
ks.SetKey("key2", b)
|
||||
|
||||
c, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
ks.SetKey("key3", c)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore, key, _, _ := setupTestKeyStoreMJWT(t)
|
||||
|
||||
t.Run("TestNoKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewMJwtSigner("mjwt.test", key)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifier(&key.PublicKey)
|
||||
_, _, err = ExtractClaims[testClaims](m, token)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("TestKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
|
||||
_, _, err = ExtractClaims[testClaims](m, token1)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = ExtractClaims[testClaims](m, token2)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClaimsFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore, key, key2, _ := setupTestKeyStoreMJWT(t)
|
||||
|
||||
t.Run("TestInvalidClaims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewMJwtSigner("mjwt.test", key)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifier(&key.PublicKey)
|
||||
_, _, err = ExtractClaims[testClaims2](m, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||
})
|
||||
|
||||
t.Run("TestDefaultKeyNoKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
||||
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifier(&key.PublicKey)
|
||||
_, _, err = ExtractClaims[testClaims](m, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
|
||||
})
|
||||
|
||||
t.Run("TestNoDefaultKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := NewMJwtVerifierWithKeyStore(nil, kStore)
|
||||
_, _, err = ExtractClaims[testClaims](m, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
|
||||
})
|
||||
|
||||
t.Run("TestKIDNonExist", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore.SetKey(mt_ExtraKID, key2)
|
||||
assert.Contains(t, kStore.ListKeys(), mt_ExtraKID)
|
||||
|
||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
||||
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore.RemoveKey(mt_ExtraKID)
|
||||
assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID)
|
||||
|
||||
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
|
||||
_, _, err = ExtractClaims[testClaims](m, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
|
||||
})
|
||||
}
|
204
signer.go
204
signer.go
@ -1,204 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const readLimit = 10240 // 10 KiB
|
||||
|
||||
var ErrNoPrivateKeyFound = errors.New("no private key found")
|
||||
|
||||
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
|
||||
// to generate MJWT tokens
|
||||
type defaultMJwtSigner struct {
|
||||
issuer string
|
||||
key *rsa.PrivateKey
|
||||
verify *defaultMJwtVerifier
|
||||
}
|
||||
|
||||
var _ Signer = &defaultMJwtSigner{}
|
||||
var _ Verifier = &defaultMJwtSigner{}
|
||||
|
||||
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
|
||||
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
||||
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
|
||||
}
|
||||
|
||||
// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
|
||||
// for no kID and a KeyStore for kID based keys
|
||||
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
|
||||
var pKey *rsa.PublicKey = nil
|
||||
if key != nil {
|
||||
pKey = &key.PublicKey
|
||||
}
|
||||
return &defaultMJwtSigner{
|
||||
issuer: issuer,
|
||||
key: key,
|
||||
verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
|
||||
// of a rsa.PrivateKey file. If the file does not exist then the file is created
|
||||
// and a new private key is generated.
|
||||
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
|
||||
privateKey, err := readOrCreatePrivateKey(file, random, bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewMJwtSigner(issuer, privateKey), nil
|
||||
}
|
||||
|
||||
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
|
||||
// rsa.PrivateKey file.
|
||||
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
|
||||
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
|
||||
}
|
||||
|
||||
// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
|
||||
// load the keys into a KeyStore; there is no default rsa.PrivateKey
|
||||
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
|
||||
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
|
||||
}
|
||||
|
||||
// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
|
||||
// file as the non kID key and the path of a directory to load the keys into a KeyStore
|
||||
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
|
||||
var err error
|
||||
|
||||
// read key
|
||||
var prv *rsa.PrivateKey = nil
|
||||
if file != "" {
|
||||
prv, err = rsaprivate.Read(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// read KeyStore
|
||||
var kStore KeyStore = nil
|
||||
if directory != "" {
|
||||
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
|
||||
}
|
||||
|
||||
// Issuer returns the name of the issuer
|
||||
func (d *defaultMJwtSigner) Issuer() string {
|
||||
return d.issuer
|
||||
}
|
||||
|
||||
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
|
||||
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
|
||||
}
|
||||
|
||||
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
|
||||
// GenerateJwt but is available for signing custom structs; uses the default key
|
||||
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||
if d.key == nil {
|
||||
return "", ErrNoPrivateKeyFound
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||
return token.SignedString(d.key)
|
||||
}
|
||||
|
||||
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
|
||||
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
|
||||
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
|
||||
}
|
||||
|
||||
// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
|
||||
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
|
||||
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
|
||||
pKey := d.verify.GetKeyStore().GetKey(kID)
|
||||
if pKey == nil {
|
||||
return "", ErrNoPrivateKeyFound
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||
token.Header["kid"] = kID
|
||||
return token.SignedString(pKey)
|
||||
}
|
||||
|
||||
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
|
||||
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
return d.verify.VerifyJwt(token, claims)
|
||||
}
|
||||
|
||||
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
|
||||
return d.key
|
||||
}
|
||||
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
|
||||
return d.verify.pub
|
||||
}
|
||||
|
||||
func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
|
||||
return d.verify.kStore.GetKeyPublic(kID)
|
||||
}
|
||||
|
||||
func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
|
||||
return d.verify.GetKeyStore()
|
||||
}
|
||||
|
||||
func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
|
||||
return d.verify.kStore.GetKey(kID)
|
||||
}
|
||||
|
||||
// readOrCreatePrivateKey returns the private key it the file already exists,
|
||||
// generates a new private key and saves it to the file, or returns an error if
|
||||
// reading or generating failed.
|
||||
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
|
||||
// read the file or return nil
|
||||
f, err := readOrEmptyFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if f == nil {
|
||||
// generate a new key
|
||||
key, err := rsa.GenerateKey(random, bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// save key to file
|
||||
err = rsaprivate.Write(file, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, err
|
||||
} else {
|
||||
// return key
|
||||
return rsaprivate.Decode(bytes.NewReader(f))
|
||||
}
|
||||
}
|
||||
|
||||
// readOrEmptyFile returns bytes and errors from os.OpenFile or (nil, nil) if the
|
||||
// file does not exist.
|
||||
func readOrEmptyFile(file string) ([]byte, error) {
|
||||
fp, err := os.Open(file)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = fp.Close() }()
|
||||
// add hard limit
|
||||
limitReader := io.LimitReader(fp, readLimit)
|
||||
raw, err := io.ReadAll(limitReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return raw, nil
|
||||
}
|
148
signer_test.go
148
signer_test.go
@ -1,148 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const st_prvExt = "prv"
|
||||
const st_pubExt = "pub"
|
||||
|
||||
func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) {
|
||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var key3 *rsa.PrivateKey = nil
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tempDir, key3, func(t *testing.T) {
|
||||
err := os.RemoveAll(tempDir)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMJwtSigner(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
NewMJwtSigner("Test", key)
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerWithKeyStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
kStore := NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
assert.Contains(t, kStore.ListKeys(), "test")
|
||||
NewMJwtSignerWithKeyStore("Test", nil, kStore)
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerFromFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
||||
assert.NoError(t, err)
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
_, err = tempKey.Write(b)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tempKey.Close())
|
||||
signer, err := NewMJwtSignerFromFile("Test", tempKey.Name())
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
||||
_, err = NewMJwtSignerFromFile("Test", tempKey.Name())
|
||||
assert.Error(t, err)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key))
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerFromFileOrCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tempKey.Close())
|
||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
||||
signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey()))
|
||||
}
|
||||
|
||||
func TestReadOrCreatePrivateKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
||||
assert.NoError(t, err)
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
_, err = tempKey.Write(b)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tempKey.Close())
|
||||
key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(key2))
|
||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
||||
key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, key3.Validate())
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerFromDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, prvKey3, cleaner := setupTestDirSigner(t)
|
||||
defer cleaner(t)
|
||||
|
||||
signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, signer.GetKeyStore().ListKeys(), 3)
|
||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, signer.GetKeyStore().ListKeys(), k)
|
||||
}
|
||||
assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3")))
|
||||
}
|
||||
|
||||
func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, prvKey3, cleaner := setupTestDirSigner(t)
|
||||
defer cleaner(t)
|
||||
|
||||
signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, signer.GetKeyStore().ListKeys(), 3)
|
||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, signer.GetKeyStore().ListKeys(), k)
|
||||
}
|
||||
assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3")))
|
||||
assert.True(t, signer.PrivateKey().Equal(signer.GetKeyStore().GetKey("key1")))
|
||||
}
|
108
verifier.go
108
verifier.go
@ -1,108 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
var ErrNoPublicKeyFound = errors.New("no public key found")
|
||||
var ErrKIDInvalid = errors.New("kid invalid")
|
||||
|
||||
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
|
||||
// MJWT tokens
|
||||
type defaultMJwtVerifier struct {
|
||||
pub *rsa.PublicKey
|
||||
kStore KeyStore
|
||||
}
|
||||
|
||||
var _ Verifier = &defaultMJwtVerifier{}
|
||||
|
||||
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
|
||||
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
|
||||
return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore())
|
||||
}
|
||||
|
||||
// NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key
|
||||
// and a KeyStore for kID based keys
|
||||
func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier {
|
||||
return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore}
|
||||
}
|
||||
|
||||
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
|
||||
// rsa.PublicKey file
|
||||
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
|
||||
return NewMJwtVerifierFromFileAndDirectory(file, "", "", "")
|
||||
}
|
||||
|
||||
// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to
|
||||
// load the keys into a KeyStore; there is no default rsa.PublicKey
|
||||
func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) {
|
||||
return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt)
|
||||
}
|
||||
|
||||
// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey
|
||||
// file as the non kID key and the path of a directory to load the keys into a KeyStore
|
||||
func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) {
|
||||
var err error
|
||||
|
||||
// read key
|
||||
var pub *rsa.PublicKey = nil
|
||||
if file != "" {
|
||||
pub, err = rsapublic.Read(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// read KeyStore
|
||||
var kStore KeyStore = nil
|
||||
if directory != "" {
|
||||
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewMJwtVerifierWithKeyStore(pub, kStore), nil
|
||||
}
|
||||
|
||||
// VerifyJwt validates and parses MJWT tokens and returns the claims
|
||||
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
kIDI, exs := token.Header["kid"]
|
||||
if exs {
|
||||
kID, ok := kIDI.(string)
|
||||
if !ok {
|
||||
return nil, ErrKIDInvalid
|
||||
}
|
||||
key := d.kStore.GetKeyPublic(kID)
|
||||
if key == nil {
|
||||
return nil, ErrNoPublicKeyFound
|
||||
} else {
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
if d.pub == nil {
|
||||
return nil, ErrNoPublicKeyFound
|
||||
}
|
||||
return d.pub, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withClaims, claims.Valid()
|
||||
}
|
||||
|
||||
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey {
|
||||
return d.pub
|
||||
}
|
||||
|
||||
func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey {
|
||||
return d.kStore.GetKeyPublic(kID)
|
||||
}
|
||||
|
||||
func (d *defaultMJwtVerifier) GetKeyStore() KeyStore {
|
||||
return d.kStore
|
||||
}
|
111
verifier_test.go
111
verifier_test.go
@ -1,111 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const vt_prvExt = "prv"
|
||||
const vt_pubExt = "pub"
|
||||
|
||||
func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) {
|
||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var key3 *rsa.PrivateKey = nil
|
||||
|
||||
if genKeys {
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return tempDir, key3, func(t *testing.T) {
|
||||
err := os.RemoveAll(tempDir)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMJwtVerifierFromFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := NewMJwtSigner("mjwt.test", key)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
|
||||
temp, err := os.CreateTemp("", "this-is-a-test-file.pem")
|
||||
assert.NoError(t, err)
|
||||
_, err = temp.Write(b)
|
||||
assert.NoError(t, err)
|
||||
file, err := NewMJwtVerifierFromFile(temp.Name())
|
||||
assert.NoError(t, err)
|
||||
_, _, err = ExtractClaims[testClaims](file, token)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(temp.Name())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNewMJwtVerifierFromDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
|
||||
defer cleaner(t)
|
||||
|
||||
s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt)
|
||||
assert.NoError(t, err)
|
||||
s.GetKeyStore().SetKey("key3", prvKey3)
|
||||
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
|
||||
assert.NoError(t, err)
|
||||
|
||||
v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = ExtractClaims[testClaims](v, token)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
|
||||
defer cleaner(t)
|
||||
|
||||
s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt)
|
||||
assert.NoError(t, err)
|
||||
s.GetKeyStore().SetKey("key3", prvKey3)
|
||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
|
||||
assert.NoError(t, err)
|
||||
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
|
||||
assert.NoError(t, err)
|
||||
|
||||
v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = ExtractClaims[testClaims](v, token1)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = ExtractClaims[testClaims](v, token2)
|
||||
assert.NoError(t, err)
|
||||
}
|
Loading…
Reference in New Issue
Block a user