Rewrite mjwt library to better support keystores

This commit is contained in:
Melon 2024-07-27 17:05:27 +01:00
parent 5d1bd6f8fd
commit cd2d80cb09
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
26 changed files with 684 additions and 1262 deletions

View File

@ -2,14 +2,13 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// AccessTokenClaims contains the JWT claims for an access token
type AccessTokenClaims struct {
Perms *claims.PermStorage `json:"per"`
Perms *PermStorage `json:"per"`
}
func (a AccessTokenClaims) Valid() error { return nil }
@ -17,21 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
func (a AccessTokenClaims) Type() string { return "access-token" }
// CreateAccessToken creates an access token with the default 15 minute duration
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
}
// CreateAccessTokenWithDuration creates an access token with a custom duration
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
}
// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
}
// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
}

View File

@ -1,55 +1,26 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateAccessToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
}
func TestCreateAccessTokenInvalid(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)
ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -2,20 +2,19 @@ package auth
import (
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"time"
)
// CreateTokenPair creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
}
// CreateTokenPairWithDuration creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
if err != nil {
return "", "", err
@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
}
return accessToken, refreshToken, nil
}
// CreateTokenPairWithKID creates an access and refresh token pair using the default
// 15 minute and 7 day durations respectively using the specified kID
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
}
// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
// custom durations for the access and refresh tokens
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
if err != nil {
return "", "", err
}
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
if err != nil {
return "", "", err
}
return accessToken, refreshToken, nil
}

View File

@ -1,29 +1,26 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/claims"
"github.com/stretchr/testify/assert"
"testing"
)
func TestCreateTokenPair(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
ps := claims.NewPermStorage()
ps := NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
@ -31,38 +28,7 @@ func TestCreateTokenPair(t *testing.T) {
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)
}
func TestCreateTokenPairWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)
ps := claims.NewPermStorage()
ps.Set("mjwt:test")
ps.Set("mjwt:test2")
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b2.Subject)
assert.Equal(t, "test2", b2.ID)

View File

@ -1,4 +1,4 @@
package claims
package auth
import (
"bufio"

View File

@ -1,4 +1,4 @@
package claims
package auth
import (
"bytes"

View File

@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
// CreateRefreshToken creates a refresh token with the default 7 day duration
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
}
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
}
// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
}
// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
}

View File

@ -1,8 +1,6 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/mjwt"
"github.com/stretchr/testify/assert"
"testing"
@ -10,35 +8,15 @@ import (
func TestCreateRefreshToken(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
s := mjwt.NewMJwtSigner("mjwt.test", key)
kStore := mjwt.NewKeyStore()
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)
assert.Equal(t, "test2", b.Claims.AccessTokenId)
}
func TestCreateRefreshTokenWithKID(t *testing.T) {
t.Parallel()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey("test", key)
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
assert.NoError(t, err)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
assert.NoError(t, err)
assert.Equal(t, "1", b.Subject)
assert.Equal(t, "test", b.ID)

View File

@ -10,11 +10,11 @@ import (
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
now := time.Now()
return (&BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: p.Issuer(),
Issuer: issuer,
Subject: sub,
Audience: aud,
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
// token and BaseTypeClaims
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
b := BaseTypeClaims[T]{
RegisteredClaims: jwt.RegisteredClaims{},
Claims: *new(T),
}
tok, err := p.VerifyJwt(token, &b)
tok, err := ks.VerifyJwt(token, &b)
return tok, b, err
}

100
claims_test.go Normal file
View File

@ -0,0 +1,100 @@
package mjwt
import (
"fmt"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type testClaims struct{ TestValue string }
func (t testClaims) Valid() error {
if t.TestValue != "hello" && t.TestValue != "world" {
return fmt.Errorf("TestValue should be hello")
}
return nil
}
func (t testClaims) Type() string { return "testClaims" }
type testClaims2 struct{ TestValue2 string }
func (t testClaims2) Valid() error {
if t.TestValue2 != "world" {
return fmt.Errorf("TestValue2 should be world")
}
return nil
}
func (t testClaims2) Type() string { return "testClaims2" }
func TestExtractClaims(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestNoKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
a, _, err := ExtractClaims[testClaims](kStore, token)
assert.NoError(t, err)
kid, _ := a.Header["kid"].(string)
assert.Equal(t, "key1", kid)
})
t.Run("TestKID", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", kStore)
assert.NoError(t, err)
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
assert.NoError(t, err)
k1, _, err := ExtractClaims[testClaims](kStore, token1)
assert.NoError(t, err)
k2, _, err := ExtractClaims[testClaims](kStore, token2)
assert.NoError(t, err)
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
})
}
func TestExtractClaimsFail(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
t.Run("TestInvalidClaims", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
_, _, err = ExtractClaims[testClaims2](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
})
t.Run("TestKIDNonExist", func(t *testing.T) {
t.Parallel()
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
assert.NoError(t, err)
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
assert.NoError(t, err)
kStore.RemoveKey("key2")
assert.NotContains(t, kStore.ListKeys(), "key2")
_, _, err = ExtractClaims[testClaims](kStore, token)
assert.Error(t, err)
assert.ErrorIs(t, err, ErrMissingPublicKey)
})
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/1f349/mjwt"
"github.com/1f349/mjwt/auth"
"github.com/1f349/mjwt/claims"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/golang-jwt/jwt/v4"
"github.com/google/subcommands"
@ -35,7 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.id, "id", "", "MJWT ID")
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key")
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
}
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
@ -51,7 +50,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
return subcommands.ExitFailure
}
ps := claims.NewPermStorage()
ps := auth.NewPermStorage()
for i := 1; i < len(args); i++ {
ps.Set(args[i])
}
@ -67,16 +66,16 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
}
var token string
if s.kID == "\x00" {
signer := mjwt.NewMJwtSigner(s.issuer, key)
token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
} else {
kStore := mjwt.NewMJwtKeyStore()
kStore.SetKey(s.kID, key)
signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore)
token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID)
kStore := mjwt.NewKeyStore()
kStore.LoadPrivateKey(s.kID, key)
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, kStore)
if err != nil {
panic("this should not fail")
}
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
return subcommands.ExitFailure

View File

@ -1,7 +1,7 @@
package claims
package mjwt
// EmptyClaims contains no claims
type EmptyClaims struct {}
type EmptyClaims struct{}
func (e EmptyClaims) Valid() error { return nil }

8
go.mod
View File

@ -10,11 +10,17 @@ require (
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/subcommands v1.2.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/spf13/afero v1.11.0
github.com/stretchr/testify v1.9.0
golang.org/x/sync v0.7.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/text v0.16.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)

25
go.sum
View File

@ -2,19 +2,38 @@ github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,39 +0,0 @@
package mjwt
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
// Signer is used to generate MJWT tokens.
// Signer can also be used as a Verifier.
type Signer interface {
Verifier
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
SignJwt(claims jwt.Claims) (string, error)
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
Issuer() string
PrivateKey() *rsa.PrivateKey
PrivateKeyOf(kID string) *rsa.PrivateKey
}
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
type Verifier interface {
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
PublicKey() *rsa.PublicKey
PublicKeyOf(kID string) *rsa.PublicKey
GetKeyStore() KeyStore
}
// KeyStore is used for the kid header support in Signer and Verifier.
type KeyStore interface {
SetKey(kID string, prvKey *rsa.PrivateKey)
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
RemoveKey(kID string)
ListKeys() []string
GetKey(kID string) *rsa.PrivateKey
GetKeyPublic(kID string) *rsa.PublicKey
ClearKeys()
}

49
issuer.go Normal file
View File

@ -0,0 +1,49 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/golang-jwt/jwt/v4"
"time"
)
type Issuer struct {
issuer string
kid string
keystore *KeyStore
}
func NewIssuer(name, kid string) (*Issuer, error) {
return NewIssuerWithKeyStore(name, kid, NewKeyStore())
}
func NewIssuerWithKeyStore(name, kid string, keystore *KeyStore) (*Issuer, error) {
i := &Issuer{name, kid, keystore}
if i.keystore.HasPrivateKey(kid) {
return i, nil
}
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
i.keystore.LoadPrivateKey(kid, key)
return i, i.keystore.SaveSingleKey(kid)
}
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
}
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
key, err := i.PrivateKey()
if err != nil {
return "", err
}
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
token.Header["kid"] = i.kid
return token.SignedString(key)
}
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
return i.keystore.GetPrivateKey(i.kid)
}

58
issuer_test.go Normal file
View File

@ -0,0 +1,58 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewIssuer(t *testing.T) {
t.Parallel()
t.Run("generate missing key for issuer", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
})
t.Run("use existing issuer key", func(t *testing.T) {
t.Parallel()
kStore := NewKeyStore()
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
kStore.LoadPrivateKey("test", key)
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privateKey, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privateKey))
})
t.Run("generate missing key in filesystem", func(t *testing.T) {
t.Parallel()
dir := afero.NewMemMapFs()
kStore := NewKeyStoreWithDir(dir)
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
assert.NoError(t, err)
assert.True(t, kStore.HasPrivateKey("test"))
assert.True(t, kStore.HasPublicKey("test"))
assert.Equal(t, "Test", issuer.issuer)
assert.Equal(t, "test", issuer.kid)
privKeyFile, err := dir.Open("test.private.pem")
assert.NoError(t, err)
privKey, err := rsaprivate.Decode(privKeyFile)
assert.NoError(t, err)
key, err := issuer.PrivateKey()
assert.NoError(t, err)
assert.True(t, key.Equal(privKey))
})
}

View File

@ -1,185 +0,0 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"os"
"path"
"strings"
"sync"
)
// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey
// or with rsa.PrivateKey instances as well.
type defaultMJwtKeyStore struct {
rwLocker *sync.RWMutex
store map[string]*rsa.PrivateKey
storePub map[string]*rsa.PublicKey
}
var _ KeyStore = &defaultMJwtKeyStore{}
// NewMJwtKeyStore creates a new defaultMJwtKeyStore.
func NewMJwtKeyStore() KeyStore {
return &defaultMJwtKeyStore{
rwLocker: new(sync.RWMutex),
store: make(map[string]*rsa.PrivateKey),
storePub: make(map[string]*rsa.PublicKey),
}
}
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
// rsa keys; the kID is the filename of the key up to the first .
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
// Create empty KeyStore
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
// List directory contents
dirEntries, err := os.ReadDir(directory)
if err != nil {
return nil, err
}
errs := make([]error, 0, len(dirEntries)/2)
// Import keys from files, based on extension
for _, entry := range dirEntries {
if entry.IsDir() {
continue
}
kID, _, _ := strings.Cut(entry.Name(), ".")
if kID == "" {
continue
}
pExt := path.Ext(entry.Name())
if pExt == "."+keyPrvExt {
// Load rsa private key with the file name as the kID (Up to the first .)
key, err2 := rsaprivate.Read(path.Join(directory, entry.Name()))
if err2 == nil {
ks.store[kID] = key
ks.storePub[kID] = &key.PublicKey
}
errs = append(errs, err2)
} else if pExt == "."+keyPubExt {
// Load rsa public key with the file name as the kID (Up to the first .)
key, err2 := rsapublic.Read(path.Join(directory, entry.Name()))
if err2 == nil {
_, exs := ks.store[kID]
if !exs {
ks.store[kID] = nil
}
ks.storePub[kID] = key
}
errs = append(errs, err2)
}
}
return ks, errors.Join(errs...)
}
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
// extensions for public and private keys
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
if ks == nil {
return errors.New("ks is nil")
}
// Create directory
err := os.MkdirAll(directory, 0700)
if err != nil {
return err
}
errs := make([]error, 0, len(ks.ListKeys())/2)
// Export all keys
for _, kID := range ks.ListKeys() {
kPrv := ks.GetKey(kID)
if kPrv != nil {
err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv)
errs = append(errs, err2)
}
kPub := ks.GetKeyPublic(kID)
if kPub != nil {
err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub)
errs = append(errs, err2)
}
}
return errors.Join(errs...)
}
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
if prvKey == nil {
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
d.store[kID] = prvKey
d.storePub[kID] = &prvKey.PublicKey
return
}
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
if pubKey == nil {
return
}
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
_, exs := d.store[kID]
if !exs {
d.store[kID] = nil
}
d.storePub[kID] = pubKey
return
}
// RemoveKey removes a specified kID from the KeyStore.
func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
delete(d.store, kID)
delete(d.storePub, kID)
return
}
// ListKeys lists the kIDs of all the keys in the KeyStore.
func (d *defaultMJwtKeyStore) ListKeys() []string {
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
lKeys := make([]string, len(d.store))
i := 0
for k := range d.store {
lKeys[i] = k
i++
}
return lKeys
}
// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found.
func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey {
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
kPrv, ok := d.store[kID]
if ok {
return kPrv
}
return nil
}
// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found.
func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey {
d.rwLocker.RLock()
defer d.rwLocker.RUnlock()
kPub, ok := d.storePub[kID]
if ok {
return kPub
}
return nil
}
// ClearKeys removes all the stored keys in the KeyStore.
func (d *defaultMJwtKeyStore) ClearKeys() {
d.rwLocker.Lock()
defer d.rwLocker.Unlock()
clear(d.store)
clear(d.storePub)
}

View File

@ -1,152 +0,0 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/stretchr/testify/assert"
"os"
"path"
"testing"
)
const kst_prvExt = "prv"
const kst_pubExt = "pub"
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
assert.NoError(t, err)
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey)
assert.NoError(t, err)
key3, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey)
assert.NoError(t, err)
}
return tempDir, func(t *testing.T) {
err := os.RemoveAll(tempDir)
assert.NoError(t, err)
}
}
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
key4, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
key5, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
const extraKID1 = "key4"
const extraKID2 = "key5"
t.Run("TestSetKey", func(t *testing.T) {
kStore.SetKey(extraKID1, key4)
assert.Contains(t, kStore.ListKeys(), extraKID1)
})
t.Run("TestSetKeyPublic", func(t *testing.T) {
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
assert.Contains(t, kStore.ListKeys(), extraKID2)
})
t.Run("TestGetKey", func(t *testing.T) {
oKey := kStore.GetKey(extraKID1)
assert.Same(t, key4, oKey)
pKey := kStore.GetKey(extraKID2)
assert.Nil(t, pKey)
aKey := kStore.GetKey("key1")
assert.NotNil(t, aKey)
bKey := kStore.GetKey("key2")
assert.NotNil(t, bKey)
cKey := kStore.GetKey("key3")
assert.Nil(t, cKey)
})
t.Run("TestGetKeyPublic", func(t *testing.T) {
oKey := kStore.GetKeyPublic(extraKID1)
assert.Same(t, &key4.PublicKey, oKey)
pKey := kStore.GetKeyPublic(extraKID2)
assert.Same(t, &key5.PublicKey, pKey)
aKey := kStore.GetKeyPublic("key1")
assert.NotNil(t, aKey)
bKey := kStore.GetKeyPublic("key2")
assert.NotNil(t, bKey)
cKey := kStore.GetKeyPublic("key3")
assert.NotNil(t, cKey)
})
t.Run("TestRemoveKey", func(t *testing.T) {
kStore.RemoveKey(extraKID1)
assert.NotContains(t, kStore.ListKeys(), extraKID1)
oKey1 := kStore.GetKey(extraKID1)
assert.Nil(t, oKey1)
oKey2 := kStore.GetKeyPublic(extraKID1)
assert.Nil(t, oKey2)
})
t.Run("TestClearKeys", func(t *testing.T) {
kStore.ClearKeys()
assert.Empty(t, kStore.ListKeys())
})
}
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
assert.NoError(t, err)
assert.Len(t, kStore.ListKeys(), 3)
kIDsToFind := []string{"key1", "key2", "key3"}
for _, k := range kIDsToFind {
assert.Contains(t, kStore.ListKeys(), k)
}
commonSubTestsKeyStore(t, kStore)
}
func TestExportKeyStore(t *testing.T) {
t.Parallel()
tempDir, cleaner := setupTestDirKeyStore(t, true)
defer cleaner(t)
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
defer cleaner2(t)
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
assert.NoError(t, err)
const prvExt2 = "v"
const pubExt2 = "b"
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
assert.NoError(t, err)
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
assert.NoError(t, err)
kIDsToFind := kStore.ListKeys()
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
for _, k := range kIDsToFind {
assert.Contains(t, kStore2.ListKeys(), k)
}
commonSubTestsKeyStore(t, kStore2)
}

233
keystore.go Normal file
View File

@ -0,0 +1,233 @@
package mjwt
import (
"crypto/rsa"
"errors"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/golang-jwt/jwt/v4"
"github.com/spf13/afero"
"golang.org/x/sync/errgroup"
"io/fs"
"path/filepath"
"runtime"
"strings"
"sync"
)
var ErrMissingPrivateKey = errors.New("missing private key")
var ErrMissingPublicKey = errors.New("missing public key")
var ErrMissingKeyPair = errors.New("missing key pair")
const PrivateStr = ".private"
const PublicStr = ".public"
const PemExt = ".pem"
const PrivatePemExt = PrivateStr + PemExt
const PublicPemExt = PublicStr + PemExt
type KeyStore struct {
mu *sync.RWMutex
store map[string]*keyPair
dir afero.Fs
}
func NewKeyStore() *KeyStore {
return &KeyStore{
mu: new(sync.RWMutex),
store: make(map[string]*keyPair),
}
}
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
keyStore := NewKeyStore()
keyStore.dir = dir
return keyStore
}
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
keyStore := NewKeyStoreWithDir(dir)
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
// maybe this is "name.private.pem"
name := filepath.Base(path)
ext := filepath.Ext(name)
if ext != PemExt {
return nil
}
name = strings.TrimSuffix(name, ext)
ext = filepath.Ext(name)
name = strings.TrimSuffix(name, ext)
switch ext {
case PrivateStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsaprivate.Decode(open)
if err != nil {
return err
}
keyStore.LoadPrivateKey(name, decode)
return nil
case PublicStr:
open, err := dir.Open(path)
if err != nil {
return err
}
decode, err := rsapublic.Decode(open)
if err != nil {
return err
}
keyStore.LoadPublicKey(name, decode)
return nil
}
// still invalid
return nil
})
return keyStore, err
}
type keyPair struct {
private *rsa.PrivateKey
public *rsa.PublicKey
}
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].private = key
k.store[kid].public = &key.PublicKey
k.mu.Unlock()
}
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
k.mu.Lock()
if k.store[kid] == nil {
k.store[kid] = &keyPair{}
}
k.store[kid].public = key
k.mu.Unlock()
}
func (k *KeyStore) RemoveKey(kid string) {
k.mu.Lock()
delete(k.store, kid)
k.mu.Unlock()
}
func (k *KeyStore) ListKeys() []string {
k.mu.RLock()
defer k.mu.RUnlock()
keys := make([]string, 0, len(k.store))
for k, _ := range k.store {
keys = append(keys, k)
}
return keys
}
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPrivateKey(kid) {
return nil, ErrMissingPrivateKey
}
return k.store[kid].private, nil
}
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
k.mu.RLock()
defer k.mu.RUnlock()
if !k.internalHasPublicKey(kid) {
return nil, ErrMissingPublicKey
}
return k.store[kid].public, nil
}
func (k *KeyStore) ClearKeys() {
k.mu.Lock()
clear(k.store)
k.mu.Unlock()
}
func (k *KeyStore) HasPrivateKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPrivateKey(kid)
}
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
v := k.store[kid]
return v != nil && v.private != nil
}
func (k *KeyStore) HasPublicKey(kid string) bool {
k.mu.RLock()
defer k.mu.RUnlock()
return k.internalHasPublicKey(kid)
}
func (k *KeyStore) internalHasPublicKey(kid string) bool {
v := k.store[kid]
return v != nil && v.public != nil
}
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, ErrMissingPublicKey
}
return k.GetPublicKey(kid)
})
if err != nil {
return nil, err
}
return withClaims, claims.Valid()
}
func (k *KeyStore) SaveSingleKey(kid string) error {
if k.dir == nil {
return nil
}
k.mu.RLock()
pair := k.store[kid]
k.mu.RUnlock()
if pair == nil {
return ErrMissingKeyPair
}
var errs []error
if pair.private != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
}
if pair.public != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
}
return errors.Join(errs...)
}
func (k *KeyStore) SaveKeys() error {
k.mu.RLock()
defer k.mu.RUnlock()
workers := new(errgroup.Group)
workers.SetLimit(runtime.NumCPU())
for kid, pair := range k.store {
workers.Go(func() error {
var errs []error
if pair.private != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
}
if pair.public != nil {
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
}
return errors.Join(errs...)
})
}
return workers.Wait()
}

175
keystore_test.go Normal file
View File

<
@ -0,0 +1,175 @@
package mjwt
import (
"crypto/rand"
"crypto/rsa"
"github.com/1f349/rsa-helper/rsaprivate"
"github.com/1f349/rsa-helper/rsapublic"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"sort"
"testing"
)
const kst_prvExt = "prv"
const kst_pubExt = "pub"
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
tempDir := afero.NewMemMapFs()
if genKeys {
key1, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
assert.NoError(t, err)
key2, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
assert.NoError(t, err)
key3, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
assert.NoError(t, err)
}
return tempDir
}
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {