diff --git a/auth/access-token.go b/auth/access-token.go index 17e8ab7..3e5ea30 100644 --- a/auth/access-token.go +++ b/auth/access-token.go @@ -2,14 +2,13 @@ package auth import ( "github.com/1f349/mjwt" - "github.com/1f349/mjwt/claims" "github.com/golang-jwt/jwt/v4" "time" ) // AccessTokenClaims contains the JWT claims for an access token type AccessTokenClaims struct { - Perms *claims.PermStorage `json:"per"` + Perms *PermStorage `json:"per"` } func (a AccessTokenClaims) Valid() error { return nil } @@ -17,21 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil } func (a AccessTokenClaims) Type() string { return "access-token" } // CreateAccessToken creates an access token with the default 15 minute duration -func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) { +func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) { return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms) } // CreateAccessTokenWithDuration creates an access token with a custom duration -func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) { +func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) { return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}) } - -// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID -func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) { - return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID) -} - -// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID -func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) { - return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID) -} diff --git a/auth/access-token_test.go b/auth/access-token_test.go index 11b2523..15a5c94 100644 --- a/auth/access-token_test.go +++ b/auth/access-token_test.go @@ -1,55 +1,26 @@ package auth import ( - "crypto/rand" - "crypto/rsa" "github.com/1f349/mjwt" - "github.com/1f349/mjwt/claims" "github.com/stretchr/testify/assert" "testing" ) func TestCreateAccessToken(t *testing.T) { t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - ps := claims.NewPermStorage() + ps := NewPermStorage() ps.Set("mjwt:test") ps.Set("mjwt:test2") - s := mjwt.NewMJwtSigner("mjwt.test", key) + kStore := mjwt.NewKeyStore() + s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore) + assert.NoError(t, err) accessToken, err := CreateAccessToken(s, "1", "test", nil, ps) assert.NoError(t, err) - _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) - assert.NoError(t, err) - assert.Equal(t, "1", b.Subject) - assert.Equal(t, "test", b.ID) - assert.True(t, b.Claims.Perms.Has("mjwt:test")) - assert.True(t, b.Claims.Perms.Has("mjwt:test2")) - assert.False(t, b.Claims.Perms.Has("mjwt:test3")) -} - -func TestCreateAccessTokenInvalid(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - - kStore := mjwt.NewMJwtKeyStore() - kStore.SetKey("test", key) - - ps := claims.NewPermStorage() - ps.Set("mjwt:test") - ps.Set("mjwt:test2") - - s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) - - accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test") - assert.NoError(t, err) - - _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken) assert.NoError(t, err) assert.Equal(t, "1", b.Subject) assert.Equal(t, "test", b.ID) diff --git a/auth/pair.go b/auth/pair.go index 7d50e55..025f181 100644 --- a/auth/pair.go +++ b/auth/pair.go @@ -2,20 +2,19 @@ package auth import ( "github.com/1f349/mjwt" - "github.com/1f349/mjwt/claims" "github.com/golang-jwt/jwt/v4" "time" ) // CreateTokenPair creates an access and refresh token pair using the default // 15 minute and 7 day durations respectively -func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) { +func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) { return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms) } // CreateTokenPairWithDuration creates an access and refresh token pair using // custom durations for the access and refresh tokens -func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) { +func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) { accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms) if err != nil { return "", "", err @@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat } return accessToken, refreshToken, nil } - -// CreateTokenPairWithKID creates an access and refresh token pair using the default -// 15 minute and 7 day durations respectively using the specified kID -func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) { - return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID) -} - -// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using -// custom durations for the access and refresh tokens -func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) { - accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID) - if err != nil { - return "", "", err - } - refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID) - if err != nil { - return "", "", err - } - return accessToken, refreshToken, nil -} diff --git a/auth/pair_test.go b/auth/pair_test.go index 0b9d135..aa69704 100644 --- a/auth/pair_test.go +++ b/auth/pair_test.go @@ -1,29 +1,26 @@ package auth import ( - "crypto/rand" - "crypto/rsa" "github.com/1f349/mjwt" - "github.com/1f349/mjwt/claims" "github.com/stretchr/testify/assert" "testing" ) func TestCreateTokenPair(t *testing.T) { t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - ps := claims.NewPermStorage() + ps := NewPermStorage() ps.Set("mjwt:test") ps.Set("mjwt:test2") - s := mjwt.NewMJwtSigner("mjwt.test", key) + kStore := mjwt.NewKeyStore() + s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore) + assert.NoError(t, err) accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps) assert.NoError(t, err) - _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) + _, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken) assert.NoError(t, err) assert.Equal(t, "1", b.Subject) assert.Equal(t, "test", b.ID) @@ -31,38 +28,7 @@ func TestCreateTokenPair(t *testing.T) { assert.True(t, b.Claims.Perms.Has("mjwt:test2")) assert.False(t, b.Claims.Perms.Has("mjwt:test3")) - _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) - assert.NoError(t, err) - assert.Equal(t, "1", b2.Subject) - assert.Equal(t, "test2", b2.ID) -} - -func TestCreateTokenPairWithKID(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - - kStore := mjwt.NewMJwtKeyStore() - kStore.SetKey("test", key) - - ps := claims.NewPermStorage() - ps.Set("mjwt:test") - ps.Set("mjwt:test2") - - s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) - - accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test") - assert.NoError(t, err) - - _, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken) - assert.NoError(t, err) - assert.Equal(t, "1", b.Subject) - assert.Equal(t, "test", b.ID) - assert.True(t, b.Claims.Perms.Has("mjwt:test")) - assert.True(t, b.Claims.Perms.Has("mjwt:test2")) - assert.False(t, b.Claims.Perms.Has("mjwt:test3")) - - _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + _, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken) assert.NoError(t, err) assert.Equal(t, "1", b2.Subject) assert.Equal(t, "test2", b2.ID) diff --git a/claims/perms.go b/auth/perms.go similarity index 99% rename from claims/perms.go rename to auth/perms.go index 861963b..7fed025 100644 --- a/claims/perms.go +++ b/auth/perms.go @@ -1,4 +1,4 @@ -package claims +package auth import ( "bufio" diff --git a/claims/perms_test.go b/auth/perms_test.go similarity index 99% rename from claims/perms_test.go rename to auth/perms_test.go index 87b7277..09bded9 100644 --- a/claims/perms_test.go +++ b/auth/perms_test.go @@ -1,4 +1,4 @@ -package claims +package auth import ( "bytes" diff --git a/auth/refresh-token.go b/auth/refresh-token.go index b0b4675..0871786 100644 --- a/auth/refresh-token.go +++ b/auth/refresh-token.go @@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil } func (r RefreshTokenClaims) Type() string { return "refresh-token" } // CreateRefreshToken creates a refresh token with the default 7 day duration -func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { +func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud) } // CreateRefreshTokenWithDuration creates a refresh token with a custom duration -func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { +func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) { return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}) } - -// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID -func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) { - return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID) -} - -// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID -func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) { - return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID) -} diff --git a/auth/refresh-token_test.go b/auth/refresh-token_test.go index bcb7521..c932921 100644 --- a/auth/refresh-token_test.go +++ b/auth/refresh-token_test.go @@ -1,8 +1,6 @@ package auth import ( - "crypto/rand" - "crypto/rsa" "github.com/1f349/mjwt" "github.com/stretchr/testify/assert" "testing" @@ -10,35 +8,15 @@ import ( func TestCreateRefreshToken(t *testing.T) { t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - s := mjwt.NewMJwtSigner("mjwt.test", key) + kStore := mjwt.NewKeyStore() + s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore) + assert.NoError(t, err) refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil) assert.NoError(t, err) - _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) - assert.NoError(t, err) - assert.Equal(t, "1", b.Subject) - assert.Equal(t, "test", b.ID) - assert.Equal(t, "test2", b.Claims.AccessTokenId) -} - -func TestCreateRefreshTokenWithKID(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - - kStore := mjwt.NewMJwtKeyStore() - kStore.SetKey("test", key) - - s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore) - - refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test") - assert.NoError(t, err) - - _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken) + _, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken) assert.NoError(t, err) assert.Equal(t, "1", b.Subject) assert.Equal(t, "test", b.ID) diff --git a/mjwt.go b/claims.go similarity index 92% rename from mjwt.go rename to claims.go index fc1a0af..f5af7ee 100644 --- a/mjwt.go +++ b/claims.go @@ -10,11 +10,11 @@ import ( var ErrClaimTypeMismatch = errors.New("claim type mismatch") // wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct -func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] { +func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] { now := time.Now() return (&BaseTypeClaims[T]{ RegisteredClaims: jwt.RegisteredClaims{ - Issuer: p.Issuer(), + Issuer: issuer, Subject: sub, Audience: aud, ExpiresAt: jwt.NewNumericDate(now.Add(dur)), @@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti // ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed // token and BaseTypeClaims -func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) { +func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) { b := BaseTypeClaims[T]{ RegisteredClaims: jwt.RegisteredClaims{}, Claims: *new(T), } - tok, err := p.VerifyJwt(token, &b) + tok, err := ks.VerifyJwt(token, &b) return tok, b, err } diff --git a/claims_test.go b/claims_test.go new file mode 100644 index 0000000..f35610d --- /dev/null +++ b/claims_test.go @@ -0,0 +1,100 @@ +package mjwt + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +type testClaims struct{ TestValue string } + +func (t testClaims) Valid() error { + if t.TestValue != "hello" && t.TestValue != "world" { + return fmt.Errorf("TestValue should be hello") + } + return nil +} + +func (t testClaims) Type() string { return "testClaims" } + +type testClaims2 struct{ TestValue2 string } + +func (t testClaims2) Valid() error { + if t.TestValue2 != "world" { + return fmt.Errorf("TestValue2 should be world") + } + return nil +} + +func (t testClaims2) Type() string { return "testClaims2" } + +func TestExtractClaims(t *testing.T) { + t.Parallel() + kStore := NewKeyStore() + + t.Run("TestNoKID", func(t *testing.T) { + t.Parallel() + s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore) + assert.NoError(t, err) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + + a, _, err := ExtractClaims[testClaims](kStore, token) + assert.NoError(t, err) + kid, _ := a.Header["kid"].(string) + assert.Equal(t, "key1", kid) + }) + + t.Run("TestKID", func(t *testing.T) { + t.Parallel() + s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore) + assert.NoError(t, err) + s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", kStore) + assert.NoError(t, err) + + token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) + assert.NoError(t, err) + + k1, _, err := ExtractClaims[testClaims](kStore, token1) + assert.NoError(t, err) + k2, _, err := ExtractClaims[testClaims](kStore, token2) + assert.NoError(t, err) + assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"]) + }) +} + +func TestExtractClaimsFail(t *testing.T) { + t.Parallel() + kStore := NewKeyStore() + + t.Run("TestInvalidClaims", func(t *testing.T) { + t.Parallel() + s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore) + assert.NoError(t, err) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) + + _, _, err = ExtractClaims[testClaims2](kStore, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrClaimTypeMismatch) + }) + + t.Run("TestKIDNonExist", func(t *testing.T) { + t.Parallel() + + s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore) + assert.NoError(t, err) + token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) + assert.NoError(t, err) + + kStore.RemoveKey("key2") + assert.NotContains(t, kStore.ListKeys(), "key2") + + _, _, err = ExtractClaims[testClaims](kStore, token) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPublicKey) + }) +} diff --git a/cmd/mjwt/access.go b/cmd/mjwt/access.go index a2cb35e..0f0efcb 100644 --- a/cmd/mjwt/access.go +++ b/cmd/mjwt/access.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/1f349/mjwt" "github.com/1f349/mjwt/auth" - "github.com/1f349/mjwt/claims" "github.com/1f349/rsa-helper/rsaprivate" "github.com/golang-jwt/jwt/v4" "github.com/google/subcommands" @@ -35,7 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&s.id, "id", "", "MJWT ID") f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT") f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)") - f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key") + f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key") } func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { @@ -51,7 +50,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} return subcommands.ExitFailure } - ps := claims.NewPermStorage() + ps := auth.NewPermStorage() for i := 1; i < len(args); i++ { ps.Set(args[i]) } @@ -67,16 +66,16 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{} } var token string - if s.kID == "\x00" { - signer := mjwt.NewMJwtSigner(s.issuer, key) - token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}) - } else { - kStore := mjwt.NewMJwtKeyStore() - kStore.SetKey(s.kID, key) - signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore) - token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID) + + kStore := mjwt.NewKeyStore() + kStore.LoadPrivateKey(s.kID, key) + + issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, kStore) + if err != nil { + panic("this should not fail") } + token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}) if err != nil { _, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err) return subcommands.ExitFailure diff --git a/claims/empty-claims.go b/empty-claims.go similarity index 77% rename from claims/empty-claims.go rename to empty-claims.go index 84ff4f9..a4f2e6e 100644 --- a/claims/empty-claims.go +++ b/empty-claims.go @@ -1,7 +1,7 @@ -package claims +package mjwt // EmptyClaims contains no claims -type EmptyClaims struct {} +type EmptyClaims struct{} func (e EmptyClaims) Valid() error { return nil } diff --git a/go.mod b/go.mod index d5857c2..b012c9e 100644 --- a/go.mod +++ b/go.mod @@ -10,11 +10,17 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/subcommands v1.2.0 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.8.4 + github.com/spf13/afero v1.11.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.7.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + golang.org/x/text v0.16.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index bc4493a..19e7eff 100644 --- a/go.sum +++ b/go.sum @@ -2,19 +2,38 @@ github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9 github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg= github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA= github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/interfaces.go b/interfaces.go deleted file mode 100644 index aeef231..0000000 --- a/interfaces.go +++ /dev/null @@ -1,39 +0,0 @@ -package mjwt - -import ( - "crypto/rsa" - "github.com/golang-jwt/jwt/v4" - "time" -) - -// Signer is used to generate MJWT tokens. -// Signer can also be used as a Verifier. -type Signer interface { - Verifier - GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) - SignJwt(claims jwt.Claims) (string, error) - GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) - SignJwtWithKID(claims jwt.Claims, kID string) (string, error) - Issuer() string - PrivateKey() *rsa.PrivateKey - PrivateKeyOf(kID string) *rsa.PrivateKey -} - -// Verifier is used to verify the validity MJWT tokens and extract the claim values. -type Verifier interface { - VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) - PublicKey() *rsa.PublicKey - PublicKeyOf(kID string) *rsa.PublicKey - GetKeyStore() KeyStore -} - -// KeyStore is used for the kid header support in Signer and Verifier. -type KeyStore interface { - SetKey(kID string, prvKey *rsa.PrivateKey) - SetKeyPublic(kID string, pubKey *rsa.PublicKey) - RemoveKey(kID string) - ListKeys() []string - GetKey(kID string) *rsa.PrivateKey - GetKeyPublic(kID string) *rsa.PublicKey - ClearKeys() -} diff --git a/issuer.go b/issuer.go new file mode 100644 index 0000000..02a1ee4 --- /dev/null +++ b/issuer.go @@ -0,0 +1,49 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/golang-jwt/jwt/v4" + "time" +) + +type Issuer struct { + issuer string + kid string + keystore *KeyStore +} + +func NewIssuer(name, kid string) (*Issuer, error) { + return NewIssuerWithKeyStore(name, kid, NewKeyStore()) +} + +func NewIssuerWithKeyStore(name, kid string, keystore *KeyStore) (*Issuer, error) { + i := &Issuer{name, kid, keystore} + if i.keystore.HasPrivateKey(kid) { + return i, nil + } + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, err + } + i.keystore.LoadPrivateKey(kid, key) + return i, i.keystore.SaveSingleKey(kid) +} + +func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { + return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims)) +} + +func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) { + key, err := i.PrivateKey() + if err != nil { + return "", err + } + token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) + token.Header["kid"] = i.kid + return token.SignedString(key) +} + +func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) { + return i.keystore.GetPrivateKey(i.kid) +} diff --git a/issuer_test.go b/issuer_test.go new file mode 100644 index 0000000..7d89a44 --- /dev/null +++ b/issuer_test.go @@ -0,0 +1,58 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewIssuer(t *testing.T) { + t.Parallel() + t.Run("generate missing key for issuer", func(t *testing.T) { + t.Parallel() + kStore := NewKeyStore() + issuer, err := NewIssuerWithKeyStore("Test", "test", kStore) + assert.NoError(t, err) + assert.True(t, kStore.HasPrivateKey("test")) + assert.True(t, kStore.HasPublicKey("test")) + assert.Equal(t, "Test", issuer.issuer) + assert.Equal(t, "test", issuer.kid) + }) + t.Run("use existing issuer key", func(t *testing.T) { + t.Parallel() + kStore := NewKeyStore() + key, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + kStore.LoadPrivateKey("test", key) + issuer, err := NewIssuerWithKeyStore("Test", "test", kStore) + assert.NoError(t, err) + assert.True(t, kStore.HasPrivateKey("test")) + assert.True(t, kStore.HasPublicKey("test")) + assert.Equal(t, "Test", issuer.issuer) + assert.Equal(t, "test", issuer.kid) + privateKey, err := issuer.PrivateKey() + assert.NoError(t, err) + assert.True(t, key.Equal(privateKey)) + }) + t.Run("generate missing key in filesystem", func(t *testing.T) { + t.Parallel() + dir := afero.NewMemMapFs() + kStore := NewKeyStoreWithDir(dir) + issuer, err := NewIssuerWithKeyStore("Test", "test", kStore) + assert.NoError(t, err) + assert.True(t, kStore.HasPrivateKey("test")) + assert.True(t, kStore.HasPublicKey("test")) + assert.Equal(t, "Test", issuer.issuer) + assert.Equal(t, "test", issuer.kid) + privKeyFile, err := dir.Open("test.private.pem") + assert.NoError(t, err) + privKey, err := rsaprivate.Decode(privKeyFile) + assert.NoError(t, err) + key, err := issuer.PrivateKey() + assert.NoError(t, err) + assert.True(t, key.Equal(privKey)) + }) +} diff --git a/key_store.go b/key_store.go deleted file mode 100644 index 082fa3c..0000000 --- a/key_store.go +++ /dev/null @@ -1,185 +0,0 @@ -package mjwt - -import ( - "crypto/rsa" - "errors" - "github.com/1f349/rsa-helper/rsaprivate" - "github.com/1f349/rsa-helper/rsapublic" - "os" - "path" - "strings" - "sync" -) - -// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey -// or with rsa.PrivateKey instances as well. -type defaultMJwtKeyStore struct { - rwLocker *sync.RWMutex - store map[string]*rsa.PrivateKey - storePub map[string]*rsa.PublicKey -} - -var _ KeyStore = &defaultMJwtKeyStore{} - -// NewMJwtKeyStore creates a new defaultMJwtKeyStore. -func NewMJwtKeyStore() KeyStore { - return &defaultMJwtKeyStore{ - rwLocker: new(sync.RWMutex), - store: make(map[string]*rsa.PrivateKey), - storePub: make(map[string]*rsa.PublicKey), - } -} - -// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private -// rsa keys; the kID is the filename of the key up to the first . -func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) { - // Create empty KeyStore - ks := NewMJwtKeyStore().(*defaultMJwtKeyStore) - // List directory contents - dirEntries, err := os.ReadDir(directory) - if err != nil { - return nil, err - } - errs := make([]error, 0, len(dirEntries)/2) - // Import keys from files, based on extension - for _, entry := range dirEntries { - if entry.IsDir() { - continue - } - kID, _, _ := strings.Cut(entry.Name(), ".") - if kID == "" { - continue - } - pExt := path.Ext(entry.Name()) - if pExt == "."+keyPrvExt { - // Load rsa private key with the file name as the kID (Up to the first .) - key, err2 := rsaprivate.Read(path.Join(directory, entry.Name())) - if err2 == nil { - ks.store[kID] = key - ks.storePub[kID] = &key.PublicKey - } - errs = append(errs, err2) - } else if pExt == "."+keyPubExt { - // Load rsa public key with the file name as the kID (Up to the first .) - key, err2 := rsapublic.Read(path.Join(directory, entry.Name())) - if err2 == nil { - _, exs := ks.store[kID] - if !exs { - ks.store[kID] = nil - } - ks.storePub[kID] = key - } - errs = append(errs, err2) - } - } - return ks, errors.Join(errs...) -} - -// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified -// extensions for public and private keys -func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error { - if ks == nil { - return errors.New("ks is nil") - } - - // Create directory - err := os.MkdirAll(directory, 0700) - if err != nil { - return err - } - - errs := make([]error, 0, len(ks.ListKeys())/2) - // Export all keys - for _, kID := range ks.ListKeys() { - kPrv := ks.GetKey(kID) - if kPrv != nil { - err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv) - errs = append(errs, err2) - } - kPub := ks.GetKeyPublic(kID) - if kPub != nil { - err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub) - errs = append(errs, err2) - } - } - return errors.Join(errs...) -} - -// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore. -func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) { - if prvKey == nil { - return - } - d.rwLocker.Lock() - defer d.rwLocker.Unlock() - d.store[kID] = prvKey - d.storePub[kID] = &prvKey.PublicKey - return -} - -// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore. -func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) { - if pubKey == nil { - return - } - d.rwLocker.Lock() - defer d.rwLocker.Unlock() - _, exs := d.store[kID] - if !exs { - d.store[kID] = nil - } - d.storePub[kID] = pubKey - return -} - -// RemoveKey removes a specified kID from the KeyStore. -func (d *defaultMJwtKeyStore) RemoveKey(kID string) { - d.rwLocker.Lock() - defer d.rwLocker.Unlock() - delete(d.store, kID) - delete(d.storePub, kID) - return -} - -// ListKeys lists the kIDs of all the keys in the KeyStore. -func (d *defaultMJwtKeyStore) ListKeys() []string { - d.rwLocker.RLock() - defer d.rwLocker.RUnlock() - lKeys := make([]string, len(d.store)) - i := 0 - for k := range d.store { - lKeys[i] = k - i++ - } - return lKeys -} - -// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found. -func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey { - d.rwLocker.RLock() - defer d.rwLocker.RUnlock() - kPrv, ok := d.store[kID] - if ok { - return kPrv - } - return nil -} - -// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found. -func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey { - d.rwLocker.RLock() - defer d.rwLocker.RUnlock() - kPub, ok := d.storePub[kID] - if ok { - return kPub - } - return nil -} - -// ClearKeys removes all the stored keys in the KeyStore. -func (d *defaultMJwtKeyStore) ClearKeys() { - d.rwLocker.Lock() - defer d.rwLocker.Unlock() - clear(d.store) - clear(d.storePub) -} diff --git a/key_store_test.go b/key_store_test.go deleted file mode 100644 index 264f689..0000000 --- a/key_store_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package mjwt - -import ( - "crypto/rand" - "crypto/rsa" - "github.com/1f349/rsa-helper/rsaprivate" - "github.com/1f349/rsa-helper/rsapublic" - "github.com/stretchr/testify/assert" - "os" - "path" - "testing" -) - -const kst_prvExt = "prv" -const kst_pubExt = "pub" - -func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) { - tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") - assert.NoError(t, err) - - if genKeys { - key1, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1) - assert.NoError(t, err) - - key2, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey) - assert.NoError(t, err) - - key3, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey) - assert.NoError(t, err) - } - - return tempDir, func(t *testing.T) { - err := os.RemoveAll(tempDir) - assert.NoError(t, err) - } -} - -func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) { - key4, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - - key5, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - - const extraKID1 = "key4" - const extraKID2 = "key5" - - t.Run("TestSetKey", func(t *testing.T) { - kStore.SetKey(extraKID1, key4) - assert.Contains(t, kStore.ListKeys(), extraKID1) - }) - - t.Run("TestSetKeyPublic", func(t *testing.T) { - kStore.SetKeyPublic(extraKID2, &key5.PublicKey) - assert.Contains(t, kStore.ListKeys(), extraKID2) - }) - - t.Run("TestGetKey", func(t *testing.T) { - oKey := kStore.GetKey(extraKID1) - assert.Same(t, key4, oKey) - pKey := kStore.GetKey(extraKID2) - assert.Nil(t, pKey) - aKey := kStore.GetKey("key1") - assert.NotNil(t, aKey) - bKey := kStore.GetKey("key2") - assert.NotNil(t, bKey) - cKey := kStore.GetKey("key3") - assert.Nil(t, cKey) - }) - - t.Run("TestGetKeyPublic", func(t *testing.T) { - oKey := kStore.GetKeyPublic(extraKID1) - assert.Same(t, &key4.PublicKey, oKey) - pKey := kStore.GetKeyPublic(extraKID2) - assert.Same(t, &key5.PublicKey, pKey) - aKey := kStore.GetKeyPublic("key1") - assert.NotNil(t, aKey) - bKey := kStore.GetKeyPublic("key2") - assert.NotNil(t, bKey) - cKey := kStore.GetKeyPublic("key3") - assert.NotNil(t, cKey) - }) - - t.Run("TestRemoveKey", func(t *testing.T) { - kStore.RemoveKey(extraKID1) - assert.NotContains(t, kStore.ListKeys(), extraKID1) - oKey1 := kStore.GetKey(extraKID1) - assert.Nil(t, oKey1) - oKey2 := kStore.GetKeyPublic(extraKID1) - assert.Nil(t, oKey2) - }) - - t.Run("TestClearKeys", func(t *testing.T) { - kStore.ClearKeys() - assert.Empty(t, kStore.ListKeys()) - }) -} - -func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { - t.Parallel() - - tempDir, cleaner := setupTestDirKeyStore(t, true) - defer cleaner(t) - - kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt) - assert.NoError(t, err) - - assert.Len(t, kStore.ListKeys(), 3) - kIDsToFind := []string{"key1", "key2", "key3"} - for _, k := range kIDsToFind { - assert.Contains(t, kStore.ListKeys(), k) - } - - commonSubTestsKeyStore(t, kStore) -} - -func TestExportKeyStore(t *testing.T) { - t.Parallel() - - tempDir, cleaner := setupTestDirKeyStore(t, true) - defer cleaner(t) - tempDir2, cleaner2 := setupTestDirKeyStore(t, false) - defer cleaner2(t) - - kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt) - assert.NoError(t, err) - - const prvExt2 = "v" - const pubExt2 = "b" - - err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2) - assert.NoError(t, err) - - kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2) - assert.NoError(t, err) - - kIDsToFind := kStore.ListKeys() - assert.Len(t, kStore2.ListKeys(), len(kIDsToFind)) - for _, k := range kIDsToFind { - assert.Contains(t, kStore2.ListKeys(), k) - } - - commonSubTestsKeyStore(t, kStore2) -} diff --git a/keystore.go b/keystore.go new file mode 100644 index 0000000..dc2e0c2 --- /dev/null +++ b/keystore.go @@ -0,0 +1,233 @@ +package mjwt + +import ( + "crypto/rsa" + "errors" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" + "github.com/golang-jwt/jwt/v4" + "github.com/spf13/afero" + "golang.org/x/sync/errgroup" + "io/fs" + "path/filepath" + "runtime" + "strings" + "sync" +) + +var ErrMissingPrivateKey = errors.New("missing private key") +var ErrMissingPublicKey = errors.New("missing public key") +var ErrMissingKeyPair = errors.New("missing key pair") + +const PrivateStr = ".private" +const PublicStr = ".public" + +const PemExt = ".pem" +const PrivatePemExt = PrivateStr + PemExt +const PublicPemExt = PublicStr + PemExt + +type KeyStore struct { + mu *sync.RWMutex + store map[string]*keyPair + dir afero.Fs +} + +func NewKeyStore() *KeyStore { + return &KeyStore{ + mu: new(sync.RWMutex), + store: make(map[string]*keyPair), + } +} + +func NewKeyStoreWithDir(dir afero.Fs) *KeyStore { + keyStore := NewKeyStore() + keyStore.dir = dir + return keyStore +} + +func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) { + keyStore := NewKeyStoreWithDir(dir) + err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error { + // maybe this is "name.private.pem" + name := filepath.Base(path) + ext := filepath.Ext(name) + if ext != PemExt { + return nil + } + + name = strings.TrimSuffix(name, ext) + ext = filepath.Ext(name) + name = strings.TrimSuffix(name, ext) + switch ext { + case PrivateStr: + open, err := dir.Open(path) + if err != nil { + return err + } + decode, err := rsaprivate.Decode(open) + if err != nil { + return err + } + keyStore.LoadPrivateKey(name, decode) + return nil + case PublicStr: + open, err := dir.Open(path) + if err != nil { + return err + } + decode, err := rsapublic.Decode(open) + if err != nil { + return err + } + keyStore.LoadPublicKey(name, decode) + return nil + } + + // still invalid + return nil + }) + return keyStore, err +} + +type keyPair struct { + private *rsa.PrivateKey + public *rsa.PublicKey +} + +func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) { + k.mu.Lock() + if k.store[kid] == nil { + k.store[kid] = &keyPair{} + } + k.store[kid].private = key + k.store[kid].public = &key.PublicKey + k.mu.Unlock() +} + +func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) { + k.mu.Lock() + if k.store[kid] == nil { + k.store[kid] = &keyPair{} + } + k.store[kid].public = key + k.mu.Unlock() +} + +func (k *KeyStore) RemoveKey(kid string) { + k.mu.Lock() + delete(k.store, kid) + k.mu.Unlock() +} + +func (k *KeyStore) ListKeys() []string { + k.mu.RLock() + defer k.mu.RUnlock() + keys := make([]string, 0, len(k.store)) + for k, _ := range k.store { + keys = append(keys, k) + } + return keys +} + +func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if !k.internalHasPrivateKey(kid) { + return nil, ErrMissingPrivateKey + } + return k.store[kid].private, nil +} + +func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if !k.internalHasPublicKey(kid) { + return nil, ErrMissingPublicKey + } + return k.store[kid].public, nil +} + +func (k *KeyStore) ClearKeys() { + k.mu.Lock() + clear(k.store) + k.mu.Unlock() +} + +func (k *KeyStore) HasPrivateKey(kid string) bool { + k.mu.RLock() + defer k.mu.RUnlock() + return k.internalHasPrivateKey(kid) +} + +func (k *KeyStore) internalHasPrivateKey(kid string) bool { + v := k.store[kid] + return v != nil && v.private != nil +} + +func (k *KeyStore) HasPublicKey(kid string) bool { + k.mu.RLock() + defer k.mu.RUnlock() + return k.internalHasPublicKey(kid) +} + +func (k *KeyStore) internalHasPublicKey(kid string) bool { + v := k.store[kid] + return v != nil && v.public != nil +} + +func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { + withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, ErrMissingPublicKey + } + return k.GetPublicKey(kid) + }) + if err != nil { + return nil, err + } + return withClaims, claims.Valid() +} + +func (k *KeyStore) SaveSingleKey(kid string) error { + if k.dir == nil { + return nil + } + + k.mu.RLock() + pair := k.store[kid] + k.mu.RUnlock() + if pair == nil { + return ErrMissingKeyPair + } + + var errs []error + if pair.private != nil { + errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600)) + } + if pair.public != nil { + errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600)) + } + return errors.Join(errs...) +} + +func (k *KeyStore) SaveKeys() error { + k.mu.RLock() + defer k.mu.RUnlock() + + workers := new(errgroup.Group) + workers.SetLimit(runtime.NumCPU()) + for kid, pair := range k.store { + workers.Go(func() error { + var errs []error + if pair.private != nil { + errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600)) + } + if pair.public != nil { + errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600)) + } + return errors.Join(errs...) + }) + } + return workers.Wait() +} diff --git a/keystore_test.go b/keystore_test.go new file mode 100644 index 0000000..9f58998 --- /dev/null +++ b/keystore_test.go @@ -0,0 +1,175 @@ +package mjwt + +import ( + "crypto/rand" + "crypto/rsa" + "github.com/1f349/rsa-helper/rsaprivate" + "github.com/1f349/rsa-helper/rsapublic" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "sort" + "testing" +) + +const kst_prvExt = "prv" +const kst_pubExt = "pub" + +func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs { + tempDir := afero.NewMemMapFs() + + if genKeys { + key1, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600) + assert.NoError(t, err) + + key2, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600) + assert.NoError(t, err) + err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600) + assert.NoError(t, err) + + key3, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600) + assert.NoError(t, err) + } + + return tempDir +} + +func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) { + key4, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + key5, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + const extraKID1 = "key4" + const extraKID2 = "key5" + + t.Run("TestSetKey", func(t *testing.T) { + kStore.LoadPrivateKey(extraKID1, key4) + assert.Contains(t, kStore.ListKeys(), extraKID1) + }) + + t.Run("TestSetKeyPublic", func(t *testing.T) { + kStore.LoadPublicKey(extraKID2, &key5.PublicKey) + assert.Contains(t, kStore.ListKeys(), extraKID2) + }) + + t.Run("TestGetPrivateKey", func(t *testing.T) { + oKey, err := kStore.GetPrivateKey(extraKID1) + assert.NoError(t, err) + assert.Same(t, key4, oKey) + pKey, err := kStore.GetPrivateKey(extraKID2) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPrivateKey) + assert.Nil(t, pKey) + aKey, err := kStore.GetPrivateKey("key1") + assert.NoError(t, err) + assert.NotNil(t, aKey) + bKey, err := kStore.GetPrivateKey("key2") + assert.NoError(t, err) + assert.NotNil(t, bKey) + cKey, err := kStore.GetPrivateKey("key3") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPrivateKey) + assert.Nil(t, cKey) + wKey, err := kStore.GetPrivateKey("key1337") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPrivateKey) + assert.Nil(t, wKey) + }) + + t.Run("TestGetPublicKey", func(t *testing.T) { + oKey, err := kStore.GetPublicKey(extraKID1) + assert.NoError(t, err) + assert.Same(t, &key4.PublicKey, oKey) + pKey, err := kStore.GetPublicKey(extraKID2) + assert.NoError(t, err) + assert.Same(t, &key5.PublicKey, pKey) + aKey, err := kStore.GetPublicKey("key1") + assert.NoError(t, err) + assert.NotNil(t, aKey) + bKey, err := kStore.GetPublicKey("key2") + assert.NoError(t, err) + assert.NotNil(t, bKey) + cKey, err := kStore.GetPublicKey("key3") + assert.NoError(t, err) + assert.NotNil(t, cKey) + wKey, err := kStore.GetPublicKey("key1337") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPublicKey) + assert.Nil(t, wKey) + }) + + t.Run("TestRemoveKey", func(t *testing.T) { + kStore.RemoveKey(extraKID1) + assert.NotContains(t, kStore.ListKeys(), extraKID1) + oKey1, err := kStore.GetPrivateKey(extraKID1) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPrivateKey) + assert.Nil(t, oKey1) + oKey2, err := kStore.GetPublicKey(extraKID1) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrMissingPublicKey) + assert.Nil(t, oKey2) + }) + + t.Run("TestClearKeys", func(t *testing.T) { + kStore.ClearKeys() + assert.Empty(t, kStore.ListKeys()) + }) +} + +func TestNewMJwtKeyStoreFromDirectory(t *testing.T) { + t.Parallel() + + tempDir := setupTestDirKeyStore(t, true) + + kStore, err := NewKeyStoreFromDir(tempDir) + assert.NoError(t, err) + + assert.Len(t, kStore.ListKeys(), 3) + kIDsToFind := []string{"key1", "key2", "key3"} + for _, k := range kIDsToFind { + assert.Contains(t, kStore.ListKeys(), k) + } + assert.True(t, kStore.HasPrivateKey("key1")) + assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key + assert.True(t, kStore.HasPrivateKey("key2")) + assert.True(t, kStore.HasPublicKey("key2")) + assert.False(t, kStore.HasPrivateKey("key3")) + assert.True(t, kStore.HasPublicKey("key3")) + + commonSubTestsKeyStore(t, kStore) +} + +func TestExportKeyStore(t *testing.T) { + t.Parallel() + + tempDir := setupTestDirKeyStore(t, true) + tempDir2 := setupTestDirKeyStore(t, false) + + kStore, err := NewKeyStoreFromDir(tempDir) + assert.NoError(t, err) + + // internally swap directory + kStore.dir = tempDir2 + + err = kStore.SaveKeys() + assert.NoError(t, err) + + kStore2, err := NewKeyStoreFromDir(tempDir2) + assert.NoError(t, err) + + kidList1 := kStore.ListKeys() + kidList2 := kStore2.ListKeys() + sort.Strings(kidList1) + sort.Strings(kidList2) + assert.Equal(t, kidList1, kidList2) + + commonSubTestsKeyStore(t, kStore2) +} diff --git a/mjwt_test.go b/mjwt_test.go deleted file mode 100644 index 1c059eb..0000000 --- a/mjwt_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package mjwt - -import ( - "crypto/rand" - "crypto/rsa" - "fmt" - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -var mt_ExtraKID = "tester" - -type testClaims struct{ TestValue string } - -func (t testClaims) Valid() error { - if t.TestValue != "hello" && t.TestValue != "world" { - return fmt.Errorf("TestValue should be hello") - } - return nil -} - -func (t testClaims) Type() string { return "testClaims" } - -type testClaims2 struct{ TestValue2 string } - -func (t testClaims2) Valid() error { - if t.TestValue2 != "world" { - return fmt.Errorf("TestValue2 should be world") - } - return nil -} - -func (t testClaims2) Type() string { return "testClaims2" } - -func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) { - ks = NewMJwtKeyStore() - var err error - - a, err = rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - ks.SetKey("key1", a) - - b, err = rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - ks.SetKey("key2", b) - - c, err = rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - ks.SetKey("key3", c) - - return -} - -func TestExtractClaims(t *testing.T) { - t.Parallel() - kStore, key, _, _ := setupTestKeyStoreMJWT(t) - - t.Run("TestNoKID", func(t *testing.T) { - t.Parallel() - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) - assert.NoError(t, err) - - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims](m, token) - assert.NoError(t, err) - }) - - t.Run("TestKID", func(t *testing.T) { - t.Parallel() - s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) - token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}) - assert.NoError(t, err) - token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2") - assert.NoError(t, err) - - m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore) - _, _, err = ExtractClaims[testClaims](m, token1) - assert.NoError(t, err) - _, _, err = ExtractClaims[testClaims](m, token2) - assert.NoError(t, err) - }) -} - -func TestExtractClaimsFail(t *testing.T) { - t.Parallel() - kStore, key, key2, _ := setupTestKeyStoreMJWT(t) - - t.Run("TestInvalidClaims", func(t *testing.T) { - t.Parallel() - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) - assert.NoError(t, err) - - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims2](m, token) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrClaimTypeMismatch) - }) - - t.Run("TestDefaultKeyNoKID", func(t *testing.T) { - t.Parallel() - s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) - token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1") - assert.NoError(t, err) - - m := NewMJwtVerifier(&key.PublicKey) - _, _, err = ExtractClaims[testClaims](m, token) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrNoPublicKeyFound) - }) - - t.Run("TestNoDefaultKey", func(t *testing.T) { - t.Parallel() - s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}) - assert.NoError(t, err) - - m := NewMJwtVerifierWithKeyStore(nil, kStore) - _, _, err = ExtractClaims[testClaims](m, token) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrNoPublicKeyFound) - }) - - t.Run("TestKIDNonExist", func(t *testing.T) { - t.Parallel() - kStore.SetKey(mt_ExtraKID, key2) - assert.Contains(t, kStore.ListKeys(), mt_ExtraKID) - - s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore) - token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID) - assert.NoError(t, err) - - kStore.RemoveKey(mt_ExtraKID) - assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID) - - m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore) - _, _, err = ExtractClaims[testClaims](m, token) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrNoPublicKeyFound) - }) -} diff --git a/signer.go b/signer.go deleted file mode 100644 index d378d8a..0000000 --- a/signer.go +++ /dev/null @@ -1,204 +0,0 @@ -package mjwt - -import ( - "bytes" - "crypto/rsa" - "errors" - "github.com/1f349/rsa-helper/rsaprivate" - "github.com/golang-jwt/jwt/v4" - "io" - "os" - "time" -) - -const readLimit = 10240 // 10 KiB - -var ErrNoPrivateKeyFound = errors.New("no private key found") - -// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name -// to generate MJWT tokens -type defaultMJwtSigner struct { - issuer string - key *rsa.PrivateKey - verify *defaultMJwtVerifier -} - -var _ Signer = &defaultMJwtSigner{} -var _ Verifier = &defaultMJwtSigner{} - -// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey -func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer { - return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore()) -} - -// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey -// for no kID and a KeyStore for kID based keys -func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer { - var pKey *rsa.PublicKey = nil - if key != nil { - pKey = &key.PublicKey - } - return &defaultMJwtSigner{ - issuer: issuer, - key: key, - verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier), - } -} - -// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path -// of a rsa.PrivateKey file. If the file does not exist then the file is created -// and a new private key is generated. -func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) { - privateKey, err := readOrCreatePrivateKey(file, random, bits) - if err != nil { - return nil, err - } - return NewMJwtSigner(issuer, privateKey), nil -} - -// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a -// rsa.PrivateKey file. -func NewMJwtSignerFromFile(issuer, file string) (Signer, error) { - return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "") -} - -// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to -// load the keys into a KeyStore; there is no default rsa.PrivateKey -func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) { - return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt) -} - -// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey -// file as the non kID key and the path of a directory to load the keys into a KeyStore -func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) { - var err error - - // read key - var prv *rsa.PrivateKey = nil - if file != "" { - prv, err = rsaprivate.Read(file) - if err != nil { - return nil, err - } - } - - // read KeyStore - var kStore KeyStore = nil - if directory != "" { - kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt) - if err != nil { - return nil, err - } - } - - return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil -} - -// Issuer returns the name of the issuer -func (d *defaultMJwtSigner) Issuer() string { - return d.issuer -} - -// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key -func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) { - return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims)) -} - -// SignJwt signs a jwt.Claims compatible struct, this is used internally by -// GenerateJwt but is available for signing custom structs; uses the default key -func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) { - if d.key == nil { - return "", ErrNoPrivateKeyFound - } - token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) - return token.SignedString(d.key) -} - -// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID -func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) { - return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID) -} - -// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by -// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID -func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) { - pKey := d.verify.GetKeyStore().GetKey(kID) - if pKey == nil { - return "", ErrNoPrivateKeyFound - } - token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped) - token.Header["kid"] = kID - return token.SignedString(pKey) -} - -// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt() -func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { - return d.verify.VerifyJwt(token, claims) -} - -func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey { - return d.key -} -func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey { - return d.verify.pub -} - -func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey { - return d.verify.kStore.GetKeyPublic(kID) -} - -func (d *defaultMJwtSigner) GetKeyStore() KeyStore { - return d.verify.GetKeyStore() -} - -func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey { - return d.verify.kStore.GetKey(kID) -} - -// readOrCreatePrivateKey returns the private key it the file already exists, -// generates a new private key and saves it to the file, or returns an error if -// reading or generating failed. -func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) { - // read the file or return nil - f, err := readOrEmptyFile(file) - if err != nil { - return nil, err - } - if f == nil { - // generate a new key - key, err := rsa.GenerateKey(random, bits) - if err != nil { - return nil, err - } - - // save key to file - err = rsaprivate.Write(file, key) - if err != nil { - return nil, err - } - return key, err - } else { - // return key - return rsaprivate.Decode(bytes.NewReader(f)) - } -} - -// readOrEmptyFile returns bytes and errors from os.OpenFile or (nil, nil) if the -// file does not exist. -func readOrEmptyFile(file string) ([]byte, error) { - fp, err := os.Open(file) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - return nil, err - } - defer func() { _ = fp.Close() }() - // add hard limit - limitReader := io.LimitReader(fp, readLimit) - raw, err := io.ReadAll(limitReader) - if err != nil { - return nil, err - } - return raw, nil -} diff --git a/signer_test.go b/signer_test.go deleted file mode 100644 index f04fbae..0000000 --- a/signer_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package mjwt - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "github.com/1f349/rsa-helper/rsaprivate" - "github.com/1f349/rsa-helper/rsapublic" - "github.com/stretchr/testify/assert" - "os" - "path" - "testing" -) - -const st_prvExt = "prv" -const st_pubExt = "pub" - -func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) { - tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") - assert.NoError(t, err) - - var key3 *rsa.PrivateKey = nil - key1, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1) - assert.NoError(t, err) - - key2, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey) - assert.NoError(t, err) - - key3, err = rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey) - assert.NoError(t, err) - - return tempDir, key3, func(t *testing.T) { - err := os.RemoveAll(tempDir) - assert.NoError(t, err) - } -} - -func TestNewMJwtSigner(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - NewMJwtSigner("Test", key) -} - -func TestNewMJwtSignerWithKeyStore(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - kStore := NewMJwtKeyStore() - kStore.SetKey("test", key) - assert.Contains(t, kStore.ListKeys(), "test") - NewMJwtSignerWithKeyStore("Test", nil, kStore) -} - -func TestNewMJwtSignerFromFile(t *testing.T) { - t.Parallel() - tempKey, err := os.CreateTemp("", "key-test-*.pem") - assert.NoError(t, err) - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - _, err = tempKey.Write(b) - assert.NoError(t, err) - assert.NoError(t, tempKey.Close()) - signer, err := NewMJwtSignerFromFile("Test", tempKey.Name()) - assert.NoError(t, err) - assert.NoError(t, os.Remove(tempKey.Name())) - _, err = NewMJwtSignerFromFile("Test", tempKey.Name()) - assert.Error(t, err) - assert.True(t, os.IsNotExist(err)) - assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key)) -} - -func TestNewMJwtSignerFromFileOrCreate(t *testing.T) { - t.Parallel() - tempKey, err := os.CreateTemp("", "key-test-*.pem") - assert.NoError(t, err) - assert.NoError(t, tempKey.Close()) - assert.NoError(t, os.Remove(tempKey.Name())) - signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048) - assert.NoError(t, err) - signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048) - assert.NoError(t, err) - assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey())) -} - -func TestReadOrCreatePrivateKey(t *testing.T) { - t.Parallel() - tempKey, err := os.CreateTemp("", "key-test-*.pem") - assert.NoError(t, err) - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - _, err = tempKey.Write(b) - assert.NoError(t, err) - assert.NoError(t, tempKey.Close()) - key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048) - assert.NoError(t, err) - assert.True(t, key.Equal(key2)) - assert.NoError(t, os.Remove(tempKey.Name())) - key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048) - assert.NoError(t, err) - assert.NoError(t, key3.Validate()) -} - -func TestNewMJwtSignerFromDirectory(t *testing.T) { - t.Parallel() - - tempDir, prvKey3, cleaner := setupTestDirSigner(t) - defer cleaner(t) - - signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt) - assert.NoError(t, err) - - assert.Len(t, signer.GetKeyStore().ListKeys(), 3) - kIDsToFind := []string{"key1", "key2", "key3"} - for _, k := range kIDsToFind { - assert.Contains(t, signer.GetKeyStore().ListKeys(), k) - } - assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3"))) -} - -func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) { - t.Parallel() - - tempDir, prvKey3, cleaner := setupTestDirSigner(t) - defer cleaner(t) - - signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt) - assert.NoError(t, err) - - assert.Len(t, signer.GetKeyStore().ListKeys(), 3) - kIDsToFind := []string{"key1", "key2", "key3"} - for _, k := range kIDsToFind { - assert.Contains(t, signer.GetKeyStore().ListKeys(), k) - } - assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3"))) - assert.True(t, signer.PrivateKey().Equal(signer.GetKeyStore().GetKey("key1"))) -} diff --git a/verifier.go b/verifier.go deleted file mode 100644 index d3d483b..0000000 --- a/verifier.go +++ /dev/null @@ -1,108 +0,0 @@ -package mjwt - -import ( - "crypto/rsa" - "errors" - "github.com/1f349/rsa-helper/rsapublic" - "github.com/golang-jwt/jwt/v4" -) - -var ErrNoPublicKeyFound = errors.New("no public key found") -var ErrKIDInvalid = errors.New("kid invalid") - -// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate -// MJWT tokens -type defaultMJwtVerifier struct { - pub *rsa.PublicKey - kStore KeyStore -} - -var _ Verifier = &defaultMJwtVerifier{} - -// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey -func NewMJwtVerifier(key *rsa.PublicKey) Verifier { - return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore()) -} - -// NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key -// and a KeyStore for kID based keys -func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier { - return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore} -} - -// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a -// rsa.PublicKey file -func NewMJwtVerifierFromFile(file string) (Verifier, error) { - return NewMJwtVerifierFromFileAndDirectory(file, "", "", "") -} - -// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to -// load the keys into a KeyStore; there is no default rsa.PublicKey -func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) { - return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt) -} - -// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey -// file as the non kID key and the path of a directory to load the keys into a KeyStore -func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) { - var err error - - // read key - var pub *rsa.PublicKey = nil - if file != "" { - pub, err = rsapublic.Read(file) - if err != nil { - return nil, err - } - } - - // read KeyStore - var kStore KeyStore = nil - if directory != "" { - kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt) - if err != nil { - return nil, err - } - } - - return NewMJwtVerifierWithKeyStore(pub, kStore), nil -} - -// VerifyJwt validates and parses MJWT tokens and returns the claims -func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) { - withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { - kIDI, exs := token.Header["kid"] - if exs { - kID, ok := kIDI.(string) - if !ok { - return nil, ErrKIDInvalid - } - key := d.kStore.GetKeyPublic(kID) - if key == nil { - return nil, ErrNoPublicKeyFound - } else { - return key, nil - } - } - if d.pub == nil { - return nil, ErrNoPublicKeyFound - } - return d.pub, nil - }) - if err != nil { - return nil, err - } - return withClaims, claims.Valid() -} - -func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey { - return d.pub -} - -func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey { - return d.kStore.GetKeyPublic(kID) -} - -func (d *defaultMJwtVerifier) GetKeyStore() KeyStore { - return d.kStore -} diff --git a/verifier_test.go b/verifier_test.go deleted file mode 100644 index 378ce1a..0000000 --- a/verifier_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package mjwt - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "github.com/1f349/rsa-helper/rsaprivate" - "github.com/1f349/rsa-helper/rsapublic" - "github.com/stretchr/testify/assert" - "os" - "path" - "testing" - "time" -) - -const vt_prvExt = "prv" -const vt_pubExt = "pub" - -func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) { - tempDir, err := os.MkdirTemp("", "this-is-a-test-dir") - assert.NoError(t, err) - - var key3 *rsa.PrivateKey = nil - - if genKeys { - key1, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1) - assert.NoError(t, err) - - key2, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey) - assert.NoError(t, err) - - key3, err = rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - err = rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey) - assert.NoError(t, err) - } - - return tempDir, key3, func(t *testing.T) { - err := os.RemoveAll(tempDir) - assert.NoError(t, err) - } -} - -func TestNewMJwtVerifierFromFile(t *testing.T) { - t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - assert.NoError(t, err) - - s := NewMJwtSigner("mjwt.test", key) - token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}) - assert.NoError(t, err) - - b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)}) - temp, err := os.CreateTemp("", "this-is-a-test-file.pem") - assert.NoError(t, err) - _, err = temp.Write(b) - assert.NoError(t, err) - file, err := NewMJwtVerifierFromFile(temp.Name()) - assert.NoError(t, err) - _, _, err = ExtractClaims[testClaims](file, token) - assert.NoError(t, err) - err = os.Remove(temp.Name()) - assert.NoError(t, err) -} - -func TestNewMJwtVerifierFromDirectory(t *testing.T) { - t.Parallel() - - tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true) - defer cleaner(t) - - s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt) - assert.NoError(t, err) - s.GetKeyStore().SetKey("key3", prvKey3) - token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3") - assert.NoError(t, err) - - v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt) - assert.NoError(t, err) - _, _, err = ExtractClaims[testClaims](v, token) - assert.NoError(t, err) -} - -func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) { - t.Parallel() - - tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true) - defer cleaner(t) - - s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt) - assert.NoError(t, err) - s.GetKeyStore().SetKey("key3", prvKey3) - token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}) - assert.NoError(t, err) - token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3") - assert.NoError(t, err) - - v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt) - assert.NoError(t, err) - _, _, err = ExtractClaims[testClaims](v, token1) - assert.NoError(t, err) - _, _, err = ExtractClaims[testClaims](v, token2) - assert.NoError(t, err) -}