Rewrite mjwt library to better support keystores

This commit is contained in:
Melon 2024-07-27 17:05:27 +01:00
parent 5d1bd6f8fd
commit cd2d80cb09
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
26 changed files with 684 additions and 1262 deletions

View File

@ -2,14 +2,13 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// AccessTokenClaims contains the JWT claims for an access token
type AccessTokenClaims struct {
Perms *claims.PermStorage `json:"per"`
Perms *PermStorage `json:"per"`
}
func (a AccessTokenClaims) Valid() error { return nil }
@ -17,21 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
func (a AccessTokenClaims) Type() string { return "access-token" }
// CreateAccessToken creates an access token with the default 15 minute duration
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
}
// CreateAccessTokenWithDuration creates an access token with a custom duration
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}
// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
}
// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
}

View File

@ -1,55 +1,26 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateAccessToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
}
func TestCreateAccessTokenInvalid(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)
ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -2,20 +2,19 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// CreateTokenPair creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
}
// CreateTokenPairWithDuration creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
if err != nil {
return "", "", err
@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
}
return accessToken, refreshToken, nil
}
// CreateTokenPairWithKID creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively using the specified kID
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
}
// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
if err != nil {
return "", "", err
}
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
if err != nil {
return "", "", err
}
return accessToken, refreshToken, nil
}

View File

@ -1,29 +1,26 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateTokenPair(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
@ -31,38 +28,7 @@ func TestCreateTokenPair(t *testing.T) {
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
}
func TestCreateTokenPairWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)
ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)

View File

@ -1,4 +1,4 @@
package claims
package auth
import (
"bufio"

View File

@ -1,4 +1,4 @@
package claims
package auth
import (
"bytes"

View File

@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
// CreateRefreshToken creates a refresh token with the default 7 day duration
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
}
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}
// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
}
// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
}

View File

@ -1,8 +1,6 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/stretchr/testify/assert"
"testing"
@ -10,35 +8,15 @@ import (
func TestCreateRefreshToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.Equal(t, "test2", b.Claims.AccessTokenId)
}
func TestCreateRefreshTokenWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -10,11 +10,11 @@ import (
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
now := time.Now()
return (&BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: p.Issuer(),
Issuer: issuer,
Subject: sub,
Audience: aud,
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
// token and BaseTypeClaims
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
b := BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{},
Claims: *new(T),
}
tok, err := p.VerifyJwt(token, &b)
tok, err := ks.VerifyJwt(token, &b)
return tok, b, err
}

100
claims_test.go Normal file
View File

@ -0,0 +1,100 @@
package mjwt
import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func TestExtractClaims(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestNoKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
a, _, err := ExtractClaims[testClaims](kStore, token)
assert.NoError(t, err)
kid, _ := a.Header["kid"].(string)
assert.Equal(t, "key1", kid)
})
t.Run("TestKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", kStore)
assert.NoError(t, err)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
k1, _, err := ExtractClaims[testClaims](kStore, token1)
assert.NoError(t, err)
k2, _, err := ExtractClaims[testClaims](kStore, token2)
assert.NoError(t, err)
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
})
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestInvalidClaims", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims2](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
})
t.Run("TestKIDNonExist", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
kStore.RemoveKey("key2")
assert.NotContains(t, kStore.ListKeys(), "key2")
_, _, err = ExtractClaims[testClaims](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
})
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"github.com/google/subcommands"
@ -35,7 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.id, "id", "", "MJWT ID")
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key")
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
}
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
@ -51,7 +50,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
return subcommands.ExitFailure
}
ps := claims.NewPermStorage()
ps := auth.NewPermStorage()
for i := 1; i < len(args); i++ {
ps.Set(args[i])
}
@ -67,16 +66,16 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
}
var token string
if s.kID == "\x00" {
signer := mjwt.NewMJwtSigner(s.issuer, key)
token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
} else {
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey(s.kID, key)
signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore)
token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID)
kStore := mjwt.NewKeyStore()
kStore.LoadPrivateKey(s.kID, key)
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, kStore)
if err != nil {
panic("this should not fail")
}
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
return subcommands.ExitFailure

View File

@ -1,7 +1,7 @@
package claims
package mjwt
// EmptyClaims contains no claims
type EmptyClaims struct {}
type EmptyClaims struct{}
func (e EmptyClaims) Valid() error { return nil }

8
go.mod
View File

@ -10,11 +10,17 @@ require (
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/subcommands v1.2.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/spf13/afero v1.11.0
github.com/stretchr/testify v1.9.0
golang.org/x/sync v0.7.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/text v0.16.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)

25
go.sum
View File

@ -2,19 +2,38 @@ github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,39 +0,0 @@
package mjwt
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
// Signer is used to generate MJWT tokens.
// Signer can also be used as a Verifier.
type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
PrivateKeyOf(kID string) *rsa.PrivateKey
}
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
PublicKeyOf(kID string) *rsa.PublicKey
GetKeyStore() KeyStore
}
// KeyStore is used for the kid header support in Signer and Verifier.
type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey)
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
RemoveKey(kID string)
ListKeys() []string
GetKey(kID string) *rsa.PrivateKey
GetKeyPublic(kID string) *rsa.PublicKey
ClearKeys()
}

49
issuer.go Normal file
View File

@ -0,0 +1,49 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
type Issuer struct {
issuer string
kid string
keystore *KeyStore
}
func NewIssuer(name, kid string) (*Issuer, error) {
return NewIssuerWithKeyStore(name, kid, NewKeyStore())
}
func NewIssuerWithKeyStore(name, kid string, keystore *KeyStore) (*Issuer, error) {
i := &Issuer{name, kid, keystore}
if i.keystore.HasPrivateKey(kid) {
return i, nil
}
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
i.keystore.LoadPrivateKey(kid, key)
return i, i.keystore.SaveSingleKey(kid)
}
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
}
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
key, err := i.PrivateKey()
if err != nil {
return "", err
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = i.kid
return token.SignedString(key)
}
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
return i.keystore.GetPrivateKey(i.kid)
}

58
issuer_test.go Normal file
View File

@ -0,0 +1,58 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewIssuer(t *testing.T) {
t.Parallel()
t.Run("generate missing key for issuer", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
})
t.Run("use existing issuer key", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore.LoadPrivateKey("test", key)
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privateKey, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privateKey))
})
t.Run("generate missing key in filesystem", func(t *testing.T) {
t.Parallel()
dir := afero.NewMemMapFs()
kStore := NewKeyStoreWithDir(dir)
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privKeyFile, err := dir.Open("test.private.pem")
assert.NoError(t, err)
privKey, err := rsaprivate.Decode(privKeyFile)
assert.NoError(t, err)
key, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privKey))
})
}

View File

@ -1,185 +0,0 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"os"
"path"
"strings"
"sync"
)
// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey
// or with rsa.PrivateKey instances as well.
type defaultMJwtKeyStore struct {
rwLocker *sync.RWMutex
store map[string]*rsa.PrivateKey
storePub map[string]*rsa.PublicKey
}
var _ KeyStore = &defaultMJwtKeyStore{}
// NewMJwtKeyStore creates a new defaultMJwtKeyStore.
func NewMJwtKeyStore() KeyStore {
return &defaultMJwtKeyStore{
rwLocker: new(sync.RWMutex),
store: make(map[string]*rsa.PrivateKey),
storePub: make(map[string]*rsa.PublicKey),
}
}
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents
dirEntries, err := os.ReadDir(directory)
if err != nil {
return nil, err
}
errs := make([]error, 0, len(dirEntries)/2)
// Import keys from files, based on extension
for _, entry := range dirEntries {
if entry.IsDir() {
continue
}
kID, _, _ := strings.Cut(entry.Name(), ".")
if kID == "" {
continue
}
pExt := path.Ext(entry.Name())
if pExt == "."+keyPrvExt {
// Load rsa private key with the file name as the kID (Up to the first .)
key, err2 := rsaprivate.Read(path.Join(directory, entry.Name()))
if err2 == nil {
ks.store[kID] = key
ks.storePub[kID] = &key.PublicKey
}
errs = append(errs, err2)
} else if pExt == "."+keyPubExt {
// Load rsa public key with the file name as the kID (Up to the first .)
key, err2 := rsapublic.Read(path.Join(directory, entry.Name()))
if err2 == nil {
_, exs := ks.store[kID]
if !exs {
ks.store[kID] = nil
}
ks.storePub[kID] = key
}
errs = append(errs, err2)
}
}
return ks, errors.Join(errs...)
}
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
// extensions for public and private keys
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
if ks == nil {
return errors.New("ks is nil")
}
// Create directory
err := os.MkdirAll(directory, 0700)
if err != nil {
return err
}
errs := make([]error, 0, len(ks.ListKeys())/2)
// Export all keys
for _, kID := range ks.ListKeys() {
kPrv := ks.GetKey(kID)
if kPrv != nil {
err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv)
errs = append(errs, err2)
}
kPub := ks.GetKeyPublic(kID)
if kPub != nil {
err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub)
errs = append(errs, err2)
}
}
return errors.Join(errs...)
}
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
if prvKey == nil {
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
d.store[kID] = prvKey
d.storePub[kID] = &prvKey.PublicKey
return
}
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
if pubKey == nil {
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
_, exs := d.store[kID]
if !exs {
d.store[kID] = nil
}
d.storePub[kID] = pubKey
return
}
// RemoveKey removes a specified kID from the KeyStore.
func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
delete(d.store, kID)
delete(d.storePub, kID)
return
}
// ListKeys lists the kIDs of all the keys in the KeyStore.
func (d *defaultMJwtKeyStore) ListKeys() []string {
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
lKeys := make([]string, len(d.store))
i := 0
for k := range d.store {
lKeys[i] = k
i++
}
return lKeys
}
// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found.
func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey {
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
kPrv, ok := d.store[kID]
if ok {
return kPrv
}
return nil
}
// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found.
func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey {
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
kPub, ok := d.storePub[kID]
if ok {
return kPub
}
return nil
}
// ClearKeys removes all the stored keys in the KeyStore.
func (d *defaultMJwtKeyStore) ClearKeys() {
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
clear(d.store)
clear(d.storePub)
}

View File

@ -1,152 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/stretchr/testify/assert"
"os"
"path"
"testing"
)
const kst_prvExt = "prv"
const kst_pubExt = "pub"
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey)
assert.NoError(t, err)
key3, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey)
assert.NoError(t, err)
}
return tempDir, func(t *testing.T) {
err := os.RemoveAll(tempDir)
assert.NoError(t, err)
}
}
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
key5, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
const extraKID1 = "key4"
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
kStore.SetKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
t.Run("TestGetKey", func(t *testing.T) {
oKey := kStore.GetKey(extraKID1)
assert.Same(t, key4, oKey)
pKey := kStore.GetKey(extraKID2)
assert.Nil(t, pKey)
aKey := kStore.GetKey("key1")
assert.NotNil(t, aKey)
bKey := kStore.GetKey("key2")
assert.NotNil(t, bKey)
cKey := kStore.GetKey("key3")
assert.Nil(t, cKey)
})
t.Run("TestGetKeyPublic", func(t *testing.T) {
oKey := kStore.GetKeyPublic(extraKID1)
assert.Same(t, &key4.PublicKey, oKey)
pKey := kStore.GetKeyPublic(extraKID2)
assert.Same(t, &key5.PublicKey, pKey)
aKey := kStore.GetKeyPublic("key1")
assert.NotNil(t, aKey)
bKey := kStore.GetKeyPublic("key2")
assert.NotNil(t, bKey)
cKey := kStore.GetKeyPublic("key3")
assert.NotNil(t, cKey)
})
t.Run("TestRemoveKey", func(t *testing.T) {
kStore.RemoveKey(extraKID1)
assert.NotContains(t, kStore.ListKeys(), extraKID1)
oKey1 := kStore.GetKey(extraKID1)
assert.Nil(t, oKey1)
oKey2 := kStore.GetKeyPublic(extraKID1)
assert.Nil(t, oKey2)
})
t.Run("TestClearKeys", func(t *testing.T) {
kStore.ClearKeys()
assert.Empty(t, kStore.ListKeys())
})
}
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
assert.NoError(t, err)
assert.Len(t, kStore.ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, kStore.ListKeys(), k)
}
commonSubTestsKeyStore(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
defer cleaner2(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
assert.NoError(t, err)
const prvExt2 = "v"
const pubExt2 = "b"
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
assert.NoError(t, err)
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
assert.NoError(t, err)
kIDsToFind := kStore.ListKeys()
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
for _, k := range kIDsToFind {
assert.Contains(t, kStore2.ListKeys(), k)
}
commonSubTestsKeyStore(t, kStore2)
}

233
keystore.go Normal file
View File

@ -0,0 +1,233 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/afero"
"golang.org/x/sync/errgroup"
"io/fs"
"path/filepath"
"runtime"
"strings"
"sync"
)
var ErrMissingPrivateKey = errors.New("missing private key")
var ErrMissingPublicKey = errors.New("missing public key")
var ErrMissingKeyPair = errors.New("missing key pair")
const PrivateStr = ".private"
const PublicStr = ".public"
const PemExt = ".pem"
const PrivatePemExt = PrivateStr + PemExt
const PublicPemExt = PublicStr + PemExt
type KeyStore struct {
mu *sync.RWMutex
store map[string]*keyPair
dir afero.Fs
}
func NewKeyStore() *KeyStore {
return &KeyStore{
mu: new(sync.RWMutex),
store: make(map[string]*keyPair),
}
}
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
keyStore := NewKeyStore()
keyStore.dir = dir
return keyStore
}
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
keyStore := NewKeyStoreWithDir(dir)
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
// maybe this is "name.private.pem"
name := filepath.Base(path)
ext := filepath.Ext(name)
if ext != PemExt {
return nil
}
name = strings.TrimSuffix(name, ext)
ext = filepath.Ext(name)
name = strings.TrimSuffix(name, ext)
switch ext {
case PrivateStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsaprivate.Decode(open)
if err != nil {
return err
}
keyStore.LoadPrivateKey(name, decode)
return nil
case PublicStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsapublic.Decode(open)
if err != nil {
return err
}
keyStore.LoadPublicKey(name, decode)
return nil
}
// still invalid
return nil
})
return keyStore, err
}
type keyPair struct {
private *rsa.PrivateKey
public *rsa.PublicKey
}
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].private = key
k.store[kid].public = &key.PublicKey
k.mu.Unlock()
}
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].public = key
k.mu.Unlock()
}
func (k *KeyStore) RemoveKey(kid string) {
k.mu.Lock()
delete(k.store, kid)
k.mu.Unlock()
}
func (k *KeyStore) ListKeys() []string {
k.mu.RLock()
defer k.mu.RUnlock()
keys := make([]string, 0, len(k.store))
for k, _ := range k.store {
keys = append(keys, k)
}
return keys
}
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPrivateKey(kid) {
return nil, ErrMissingPrivateKey
}
return k.store[kid].private, nil
}
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPublicKey(kid) {
return nil, ErrMissingPublicKey
}
return k.store[kid].public, nil
}
func (k *KeyStore) ClearKeys() {
k.mu.Lock()
clear(k.store)
k.mu.Unlock()
}
func (k *KeyStore) HasPrivateKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPrivateKey(kid)
}
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
v := k.store[kid]
return v != nil && v.private != nil
}
func (k *KeyStore) HasPublicKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPublicKey(kid)
}
func (k *KeyStore) internalHasPublicKey(kid string) bool {
v := k.store[kid]
return v != nil && v.public != nil
}
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, ErrMissingPublicKey
}
return k.GetPublicKey(kid)
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
func (k *KeyStore) SaveSingleKey(kid string) error {
if k.dir == nil {
return nil
}
k.mu.RLock()
pair := k.store[kid]
k.mu.RUnlock()
if pair == nil {
return ErrMissingKeyPair
}
var errs []error
if pair.private != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
}
if pair.public != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
}
return errors.Join(errs...)
}
func (k *KeyStore) SaveKeys() error {
k.mu.RLock()
defer k.mu.RUnlock()
workers := new(errgroup.Group)
workers.SetLimit(runtime.NumCPU())
for kid, pair := range k.store {
workers.Go(func() error {
var errs []error
if pair.private != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
}
if pair.public != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
}
return errors.Join(errs...)
})
}
return workers.Wait()
}

175
keystore_test.go Normal file
View File

@ -0,0 +1,175 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"sort"
"testing"
)
const kst_prvExt = "prv"
const kst_pubExt = "pub"
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
tempDir := afero.NewMemMapFs()
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
assert.NoError(t, err)
key3, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
assert.NoError(t, err)
}
return tempDir
}
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
key5, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
const extraKID1 = "key4"
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
kStore.LoadPrivateKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
kStore.LoadPublicKey(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
t.Run("TestGetPrivateKey", func(t *testing.T) {
oKey, err := kStore.GetPrivateKey(extraKID1)
assert.NoError(t, err)
assert.Same(t, key4, oKey)
pKey, err := kStore.GetPrivateKey(extraKID2)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, pKey)
aKey, err := kStore.GetPrivateKey("key1")
assert.NoError(t, err)
assert.NotNil(t, aKey)
bKey, err := kStore.GetPrivateKey("key2")
assert.NoError(t, err)
assert.NotNil(t, bKey)
cKey, err := kStore.GetPrivateKey("key3")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, cKey)
wKey, err := kStore.GetPrivateKey("key1337")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, wKey)
})
t.Run("TestGetPublicKey", func(t *testing.T) {
oKey, err := kStore.GetPublicKey(extraKID1)
assert.NoError(t, err)
assert.Same(t, &key4.PublicKey, oKey)
pKey, err := kStore.GetPublicKey(extraKID2)
assert.NoError(t, err)
assert.Same(t, &key5.PublicKey, pKey)
aKey, err := kStore.GetPublicKey("key1")
assert.NoError(t, err)
assert.NotNil(t, aKey)
bKey, err := kStore.GetPublicKey("key2")
assert.NoError(t, err)
assert.NotNil(t, bKey)
cKey, err := kStore.GetPublicKey("key3")
assert.NoError(t, err)
assert.NotNil(t, cKey)
wKey, err := kStore.GetPublicKey("key1337")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
assert.Nil(t, wKey)
})
t.Run("TestRemoveKey", func(t *testing.T) {
kStore.RemoveKey(extraKID1)
assert.NotContains(t, kStore.ListKeys(), extraKID1)
oKey1, err := kStore.GetPrivateKey(extraKID1)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPrivateKey)
assert.Nil(t, oKey1)
oKey2, err := kStore.GetPublicKey(extraKID1)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
assert.Nil(t, oKey2)
})
t.Run("TestClearKeys", func(t *testing.T) {
kStore.ClearKeys()
assert.Empty(t, kStore.ListKeys())
})
}
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir := setupTestDirKeyStore(t, true)
kStore, err := NewKeyStoreFromDir(tempDir)
assert.NoError(t, err)
assert.Len(t, kStore.ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, kStore.ListKeys(), k)
}
assert.True(t, kStore.HasPrivateKey("key1"))
assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key
assert.True(t, kStore.HasPrivateKey("key2"))
assert.True(t, kStore.HasPublicKey("key2"))
assert.False(t, kStore.HasPrivateKey("key3"))
assert.True(t, kStore.HasPublicKey("key3"))
commonSubTestsKeyStore(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir := setupTestDirKeyStore(t, true)
tempDir2 := setupTestDirKeyStore(t, false)
kStore, err := NewKeyStoreFromDir(tempDir)
assert.NoError(t, err)
// internally swap directory
kStore.dir = tempDir2
err = kStore.SaveKeys()
assert.NoError(t, err)
kStore2, err := NewKeyStoreFromDir(tempDir2)
assert.NoError(t, err)
kidList1 := kStore.ListKeys()
kidList2 := kStore2.ListKeys()
sort.Strings(kidList1)
sort.Strings(kidList2)
assert.Equal(t, kidList1, kidList2)
commonSubTestsKeyStore(t, kStore2)
}

View File

@ -1,143 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
var mt_ExtraKID = "tester"
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) {
ks = NewMJwtKeyStore()
var err error
a, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ks.SetKey("key1", a)
b, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ks.SetKey("key2", b)
c, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ks.SetKey("key3", c)
return
}
func TestExtractClaims(t *testing.T) {
t.Parallel()
kStore, key, _, _ := setupTestKeyStoreMJWT(t)
t.Run("TestNoKID", func(t *testing.T) {
t.Parallel()
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims](m, token)
assert.NoError(t, err)
})
t.Run("TestKID", func(t *testing.T) {
t.Parallel()
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2")
assert.NoError(t, err)
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
_, _, err = ExtractClaims[testClaims](m, token1)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](m, token2)
assert.NoError(t, err)
})
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
kStore, key, key2, _ := setupTestKeyStoreMJWT(t)
t.Run("TestInvalidClaims", func(t *testing.T) {
t.Parallel()
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims2](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
})
t.Run("TestDefaultKeyNoKID", func(t *testing.T) {
t.Parallel()
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1")
assert.NoError(t, err)
m := NewMJwtVerifier(&key.PublicKey)
_, _, err = ExtractClaims[testClaims](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
})
t.Run("TestNoDefaultKey", func(t *testing.T) {
t.Parallel()
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
m := NewMJwtVerifierWithKeyStore(nil, kStore)
_, _, err = ExtractClaims[testClaims](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
})
t.Run("TestKIDNonExist", func(t *testing.T) {
t.Parallel()
kStore.SetKey(mt_ExtraKID, key2)
assert.Contains(t, kStore.ListKeys(), mt_ExtraKID)
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID)
assert.NoError(t, err)
kStore.RemoveKey(mt_ExtraKID)
assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID)
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
_, _, err = ExtractClaims[testClaims](m, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
})
}

204
signer.go
View File

@ -1,204 +0,0 @@
package mjwt
import (
"bytes"
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"io"
"os"
"time"
)
const readLimit = 10240 // 10 KiB
var ErrNoPrivateKeyFound = errors.New("no private key found")
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
// to generate MJWT tokens
type defaultMJwtSigner struct {
issuer string
key *rsa.PrivateKey
verify *defaultMJwtVerifier
}
var _ Signer = &defaultMJwtSigner{}
var _ Verifier = &defaultMJwtSigner{}
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
}
// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
// for no kID and a KeyStore for kID based keys
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
var pKey *rsa.PublicKey = nil
if key != nil {
pKey = &key.PublicKey
}
return &defaultMJwtSigner{
issuer: issuer,
key: key,
verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier),
}
}
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
// of a rsa.PrivateKey file. If the file does not exist then the file is created
// and a new private key is generated.
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
privateKey, err := readOrCreatePrivateKey(file, random, bits)
if err != nil {
return nil, err
}
return NewMJwtSigner(issuer, privateKey), nil
}
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
// rsa.PrivateKey file.
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
}
// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PrivateKey
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
}
// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
var err error
// read key
var prv *rsa.PrivateKey = nil
if file != "" {
prv, err = rsaprivate.Read(file)
if err != nil {
return nil, err
}
}
// read KeyStore
var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
}
// Issuer returns the name of the issuer
func (d *defaultMJwtSigner) Issuer() string {
return d.issuer
}
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
}
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; uses the default key
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
if d.key == nil {
return "", ErrNoPrivateKeyFound
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
return token.SignedString(d.key)
}
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
}
// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
pKey := d.verify.GetKeyStore().GetKey(kID)
if pKey == nil {
return "", ErrNoPrivateKeyFound
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = kID
return token.SignedString(pKey)
}
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
return d.verify.VerifyJwt(token, claims)
}
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
return d.key
}
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
return d.verify.pub
}
func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
return d.verify.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
return d.verify.GetKeyStore()
}
func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
return d.verify.kStore.GetKey(kID)
}
// readOrCreatePrivateKey returns the private key it the file already exists,
// generates a new private key and saves it to the file, or returns an error if
// reading or generating failed.
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
// read the file or return nil
f, err := readOrEmptyFile(file)
if err != nil {
return nil, err
}
if f == nil {
// generate a new key
key, err := rsa.GenerateKey(random, bits)
if err != nil {
return nil, err
}
// save key to file
err = rsaprivate.Write(file, key)
if err != nil {
return nil, err
}
return key, err
} else {
// return key
return rsaprivate.Decode(bytes.NewReader(f))
}
}
// readOrEmptyFile returns bytes and errors from os.OpenFile or (nil, nil) if the
// file does not exist.
func readOrEmptyFile(file string) ([]byte, error) {
fp, err := os.Open(file)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
defer func() { _ = fp.Close() }()
// add hard limit
limitReader := io.LimitReader(fp, readLimit)
raw, err := io.ReadAll(limitReader)
if err != nil {
return nil, err
}
return raw, nil
}

View File

@ -1,148 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/stretchr/testify/assert"
"os"
"path"
"testing"
)
const st_prvExt = "prv"
const st_pubExt = "pub"
func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
var key3 *rsa.PrivateKey = nil
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey)
assert.NoError(t, err)
key3, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey)
assert.NoError(t, err)
return tempDir, key3, func(t *testing.T) {
err := os.RemoveAll(tempDir)
assert.NoError(t, err)
}
}
func TestNewMJwtSigner(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
NewMJwtSigner("Test", key)
}
func TestNewMJwtSignerWithKeyStore(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := NewMJwtKeyStore()
kStore.SetKey("test", key)
assert.Contains(t, kStore.ListKeys(), "test")
NewMJwtSignerWithKeyStore("Test", nil, kStore)
}
func TestNewMJwtSignerFromFile(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
_, err = tempKey.Write(b)
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
signer, err := NewMJwtSignerFromFile("Test", tempKey.Name())
assert.NoError(t, err)
assert.NoError(t, os.Remove(tempKey.Name()))
_, err = NewMJwtSignerFromFile("Test", tempKey.Name())
assert.Error(t, err)
assert.True(t, os.IsNotExist(err))
assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key))
}
func TestNewMJwtSignerFromFileOrCreate(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
assert.NoError(t, os.Remove(tempKey.Name()))
signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey()))
}
func TestReadOrCreatePrivateKey(t *testing.T) {
t.Parallel()
tempKey, err := os.CreateTemp("", "key-test-*.pem")
assert.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
_, err = tempKey.Write(b)
assert.NoError(t, err)
assert.NoError(t, tempKey.Close())
key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.True(t, key.Equal(key2))
assert.NoError(t, os.Remove(tempKey.Name()))
key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
assert.NoError(t, err)
assert.NoError(t, key3.Validate())
}
func TestNewMJwtSignerFromDirectory(t *testing.T) {
t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirSigner(t)
defer cleaner(t)
signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt)
assert.NoError(t, err)
assert.Len(t, signer.GetKeyStore().ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, signer.GetKeyStore().ListKeys(), k)
}
assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3")))
}
func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) {
t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirSigner(t)
defer cleaner(t)
signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt)
assert.NoError(t, err)
assert.Len(t, signer.GetKeyStore().ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, signer.GetKeyStore().ListKeys(), k)
}
assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3")))
assert.True(t, signer.PrivateKey().Equal(signer.GetKeyStore().GetKey("key1")))
}

View File

@ -1,108 +0,0 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4"
)
var ErrNoPublicKeyFound = errors.New("no public key found")
var ErrKIDInvalid = errors.New("kid invalid")
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
// MJWT tokens
type defaultMJwtVerifier struct {
pub *rsa.PublicKey
kStore KeyStore
}
var _ Verifier = &defaultMJwtVerifier{}
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore())
}
// NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key
// and a KeyStore for kID based keys
func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier {
return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore}
}
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
// rsa.PublicKey file
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
return NewMJwtVerifierFromFileAndDirectory(file, "", "", "")
}
// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to
// load the keys into a KeyStore; there is no default rsa.PublicKey
func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) {
return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt)
}
// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey
// file as the non kID key and the path of a directory to load the keys into a KeyStore
func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) {
var err error
// read key
var pub *rsa.PublicKey = nil
if file != "" {
pub, err = rsapublic.Read(file)
if err != nil {
return nil, err
}
}
// read KeyStore
var kStore KeyStore = nil
if directory != "" {
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
if err != nil {
return nil, err
}
}
return NewMJwtVerifierWithKeyStore(pub, kStore), nil
}
// VerifyJwt validates and parses MJWT tokens and returns the claims
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kIDI, exs := token.Header["kid"]
if exs {
kID, ok := kIDI.(string)
if !ok {
return nil, ErrKIDInvalid
}
key := d.kStore.GetKeyPublic(kID)
if key == nil {
return nil, ErrNoPublicKeyFound
} else {
return key, nil
}
}
if d.pub == nil {
return nil, ErrNoPublicKeyFound
}
return d.pub, nil
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey {
return d.pub
}
func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey {
return d.kStore.GetKeyPublic(kID)
}
func (d *defaultMJwtVerifier) GetKeyStore() KeyStore {
return d.kStore
}

View File

@ -1,111 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/stretchr/testify/assert"
"os"
"path"
"testing"
"time"
)
const vt_prvExt = "prv"
const vt_pubExt = "pub"
func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
var key3 *rsa.PrivateKey = nil
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey)
assert.NoError(t, err)
key3, err = rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey)
assert.NoError(t, err)
}
return tempDir, key3, func(t *testing.T) {
err := os.RemoveAll(tempDir)
assert.NoError(t, err)
}
}
func TestNewMJwtVerifierFromFile(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := NewMJwtSigner("mjwt.test", key)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
assert.NoError(t, err)
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
temp, err := os.CreateTemp("", "this-is-a-test-file.pem")
assert.NoError(t, err)
_, err = temp.Write(b)
assert.NoError(t, err)
file, err := NewMJwtVerifierFromFile(temp.Name())
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](file, token)
assert.NoError(t, err)
err = os.Remove(temp.Name())
assert.NoError(t, err)
}
func TestNewMJwtVerifierFromDirectory(t *testing.T) {
t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
defer cleaner(t)
s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
s.GetKeyStore().SetKey("key3", prvKey3)
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
assert.NoError(t, err)
v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](v, token)
assert.NoError(t, err)
}
func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) {
t.Parallel()
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
defer cleaner(t)
s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
s.GetKeyStore().SetKey("key3", prvKey3)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
assert.NoError(t, err)
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
assert.NoError(t, err)
v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](v, token1)
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims](v, token2)
assert.NoError(t, err)
}