mirror of
https://github.com/1f349/mjwt.git
synced 2024-12-10 20:11:35 +00:00
Rewrite mjwt library to better support keystores
This commit is contained in:
parent
5d1bd6f8fd
commit
cd2d80cb09
@ -2,14 +2,13 @@ package auth
|
||||
|
||||
import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccessTokenClaims contains the JWT claims for an access token
|
||||
type AccessTokenClaims struct {
|
||||
Perms *claims.PermStorage `json:"per"`
|
||||
Perms *PermStorage `json:"per"`
|
||||
}
|
||||
|
||||
func (a AccessTokenClaims) Valid() error { return nil }
|
||||
@ -17,21 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
|
||||
func (a AccessTokenClaims) Type() string { return "access-token" }
|
||||
|
||||
// CreateAccessToken creates an access token with the default 15 minute duration
|
||||
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
||||
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
||||
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
|
||||
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
|
||||
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
|
||||
}
|
||||
|
||||
// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
|
||||
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
|
||||
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
|
||||
}
|
||||
|
@ -1,55 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
}
|
||||
|
||||
func TestCreateAccessTokenInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
||||
|
||||
accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
|
25
auth/pair.go
25
auth/pair.go
@ -2,20 +2,19 @@ package auth
|
||||
|
||||
import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateTokenPair creates an access and refresh token pair using the default
|
||||
// 15 minute and 7 day durations respectively
|
||||
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
|
||||
}
|
||||
|
||||
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
||||
// custom durations for the access and refresh tokens
|
||||
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
||||
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
|
||||
}
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
// CreateTokenPairWithKID creates an access and refresh token pair using the default
|
||||
// 15 minute and 7 day durations respectively using the specified kID
|
||||
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
|
||||
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
|
||||
}
|
||||
|
||||
// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
|
||||
// custom durations for the access and refresh tokens
|
||||
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
|
||||
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
@ -1,29 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateTokenPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
@ -31,38 +28,7 @@ func TestCreateTokenPair(t *testing.T) {
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b2.Subject)
|
||||
assert.Equal(t, "test2", b2.ID)
|
||||
}
|
||||
|
||||
func TestCreateTokenPairWithKID(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps.Set("mjwt:test")
|
||||
ps.Set("mjwt:test2")
|
||||
|
||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
||||
|
||||
accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b2.Subject)
|
||||
assert.Equal(t, "test2", b2.ID)
|
||||
|
@ -1,4 +1,4 @@
|
||||
package claims
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
@ -1,4 +1,4 @@
|
||||
package claims
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
|
||||
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
||||
|
||||
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
||||
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
||||
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
|
||||
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
|
||||
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
|
||||
}
|
||||
|
||||
// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
|
||||
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
|
||||
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
@ -10,35 +8,15 @@ import (
|
||||
|
||||
func TestCreateRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
||||
kStore := mjwt.NewKeyStore()
|
||||
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
assert.Equal(t, "test2", b.Claims.AccessTokenId)
|
||||
}
|
||||
|
||||
func TestCreateRefreshTokenWithKID(t *testing.T) {
|
||||
t.Parallel()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey("test", key)
|
||||
|
||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
||||
|
||||
refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", b.Subject)
|
||||
assert.Equal(t, "test", b.ID)
|
||||
|
@ -10,11 +10,11 @@ import (
|
||||
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
||||
|
||||
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
||||
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||
now := time.Now()
|
||||
return (&BaseTypeClaims[T]{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: p.Issuer(),
|
||||
Issuer: issuer,
|
||||
Subject: sub,
|
||||
Audience: aud,
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
||||
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
|
||||
|
||||
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
||||
// token and BaseTypeClaims
|
||||
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||
b := BaseTypeClaims[T]{
|
||||
RegisteredClaims: jwt.RegisteredClaims{},
|
||||
Claims: *new(T),
|
||||
}
|
||||
tok, err := p.VerifyJwt(token, &b)
|
||||
tok, err := ks.VerifyJwt(token, &b)
|
||||
return tok, b, err
|
||||
}
|
||||
|
100
claims_test.go
Normal file
100
claims_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testClaims struct{ TestValue string }
|
||||
|
||||
func (t testClaims) Valid() error {
|
||||
if t.TestValue != "hello" && t.TestValue != "world" {
|
||||
return fmt.Errorf("TestValue should be hello")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims) Type() string { return "testClaims" }
|
||||
|
||||
type testClaims2 struct{ TestValue2 string }
|
||||
|
||||
func (t testClaims2) Valid() error {
|
||||
if t.TestValue2 != "world" {
|
||||
return fmt.Errorf("TestValue2 should be world")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t testClaims2) Type() string { return "testClaims2" }
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
|
||||
t.Run("TestNoKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
a, _, err := ExtractClaims[testClaims](kStore, token)
|
||||
assert.NoError(t, err)
|
||||
kid, _ := a.Header["kid"].(string)
|
||||
assert.Equal(t, "key1", kid)
|
||||
})
|
||||
|
||||
t.Run("TestKID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||
assert.NoError(t, err)
|
||||
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", kStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
k1, _, err := ExtractClaims[testClaims](kStore, token1)
|
||||
assert.NoError(t, err)
|
||||
k2, _, err := ExtractClaims[testClaims](kStore, token2)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClaimsFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
|
||||
t.Run("TestInvalidClaims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = ExtractClaims[testClaims2](kStore, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||
})
|
||||
|
||||
t.Run("TestKIDNonExist", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||
assert.NoError(t, err)
|
||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore.RemoveKey("key2")
|
||||
assert.NotContains(t, kStore.ListKeys(), "key2")
|
||||
|
||||
_, _, err = ExtractClaims[testClaims](kStore, token)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||
})
|
||||
}
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/auth"
|
||||
"github.com/1f349/mjwt/claims"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/subcommands"
|
||||
@ -35,7 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
|
||||
f.StringVar(&s.id, "id", "", "MJWT ID")
|
||||
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
|
||||
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
|
||||
f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key")
|
||||
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
|
||||
}
|
||||
|
||||
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||
@ -51,7 +50,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
|
||||
ps := claims.NewPermStorage()
|
||||
ps := auth.NewPermStorage()
|
||||
for i := 1; i < len(args); i++ {
|
||||
ps.Set(args[i])
|
||||
}
|
||||
@ -67,16 +66,16 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
||||
}
|
||||
|
||||
var token string
|
||||
if s.kID == "\x00" {
|
||||
signer := mjwt.NewMJwtSigner(s.issuer, key)
|
||||
token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||
} else {
|
||||
kStore := mjwt.NewMJwtKeyStore()
|
||||
kStore.SetKey(s.kID, key)
|
||||
signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore)
|
||||
token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID)
|
||||
|
||||
kStore := mjwt.NewKeyStore()
|
||||
kStore.LoadPrivateKey(s.kID, key)
|
||||
|
||||
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, kStore)
|
||||
if err != nil {
|
||||
panic("this should not fail")
|
||||
}
|
||||
|
||||
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
|
||||
return subcommands.ExitFailure
|
||||
|
@ -1,7 +1,7 @@
|
||||
package claims
|
||||
package mjwt
|
||||
|
||||
// EmptyClaims contains no claims
|
||||
type EmptyClaims struct {}
|
||||
type EmptyClaims struct{}
|
||||
|
||||
func (e EmptyClaims) Valid() error { return nil }
|
||||
|
8
go.mod
8
go.mod
@ -10,11 +10,17 @@ require (
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/google/subcommands v1.2.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/spf13/afero v1.11.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/sync v0.7.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
)
|
||||
|
25
go.sum
25
go.sum
@ -2,19 +2,38 @@ github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9
|
||||
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
||||
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
|
||||
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -1,39 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Signer is used to generate MJWT tokens.
|
||||
// Signer can also be used as a Verifier.
|
||||
type Signer interface {
|
||||
Verifier
|
||||
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
|
||||
SignJwt(claims jwt.Claims) (string, error)
|
||||
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
|
||||
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
|
||||
Issuer() string
|
||||
PrivateKey() *rsa.PrivateKey
|
||||
PrivateKeyOf(kID string) *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
|
||||
type Verifier interface {
|
||||
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
|
||||
PublicKey() *rsa.PublicKey
|
||||
PublicKeyOf(kID string) *rsa.PublicKey
|
||||
GetKeyStore() KeyStore
|
||||
}
|
||||
|
||||
// KeyStore is used for the kid header support in Signer and Verifier.
|
||||
type KeyStore interface {
|
||||
SetKey(kID string, prvKey *rsa.PrivateKey)
|
||||
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
|
||||
RemoveKey(kID string)
|
||||
ListKeys() []string
|
||||
GetKey(kID string) *rsa.PrivateKey
|
||||
GetKeyPublic(kID string) *rsa.PublicKey
|
||||
ClearKeys()
|
||||
}
|
49
issuer.go
Normal file
49
issuer.go
Normal file
@ -0,0 +1,49 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Issuer struct {
|
||||
issuer string
|
||||
kid string
|
||||
keystore *KeyStore
|
||||
}
|
||||
|
||||
func NewIssuer(name, kid string) (*Issuer, error) {
|
||||
return NewIssuerWithKeyStore(name, kid, NewKeyStore())
|
||||
}
|
||||
|
||||
func NewIssuerWithKeyStore(name, kid string, keystore *KeyStore) (*Issuer, error) {
|
||||
i := &Issuer{name, kid, keystore}
|
||||
if i.keystore.HasPrivateKey(kid) {
|
||||
return i, nil
|
||||
}
|
||||
key, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.keystore.LoadPrivateKey(kid, key)
|
||||
return i, i.keystore.SaveSingleKey(kid)
|
||||
}
|
||||
|
||||
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
|
||||
}
|
||||
|
||||
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||
key, err := i.PrivateKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||
token.Header["kid"] = i.kid
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
|
||||
return i.keystore.GetPrivateKey(i.kid)
|
||||
}
|
58
issuer_test.go
Normal file
58
issuer_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIssuer(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("generate missing key for issuer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
})
|
||||
t.Run("use existing issuer key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
kStore := NewKeyStore()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
kStore.LoadPrivateKey("test", key)
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
privateKey, err := issuer.PrivateKey()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(privateKey))
|
||||
})
|
||||
t.Run("generate missing key in filesystem", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := afero.NewMemMapFs()
|
||||
kStore := NewKeyStoreWithDir(dir)
|
||||
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, kStore.HasPrivateKey("test"))
|
||||
assert.True(t, kStore.HasPublicKey("test"))
|
||||
assert.Equal(t, "Test", issuer.issuer)
|
||||
assert.Equal(t, "test", issuer.kid)
|
||||
privKeyFile, err := dir.Open("test.private.pem")
|
||||
assert.NoError(t, err)
|
||||
privKey, err := rsaprivate.Decode(privKeyFile)
|
||||
assert.NoError(t, err)
|
||||
key, err := issuer.PrivateKey()
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, key.Equal(privKey))
|
||||
})
|
||||
}
|
185
key_store.go
185
key_store.go
@ -1,185 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey
|
||||
// or with rsa.PrivateKey instances as well.
|
||||
type defaultMJwtKeyStore struct {
|
||||
rwLocker *sync.RWMutex
|
||||
store map[string]*rsa.PrivateKey
|
||||
storePub map[string]*rsa.PublicKey
|
||||
}
|
||||
|
||||
var _ KeyStore = &defaultMJwtKeyStore{}
|
||||
|
||||
// NewMJwtKeyStore creates a new defaultMJwtKeyStore.
|
||||
func NewMJwtKeyStore() KeyStore {
|
||||
return &defaultMJwtKeyStore{
|
||||
rwLocker: new(sync.RWMutex),
|
||||
store: make(map[string]*rsa.PrivateKey),
|
||||
storePub: make(map[string]*rsa.PublicKey),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
|
||||
// rsa keys; the kID is the filename of the key up to the first .
|
||||
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
|
||||
// Create empty KeyStore
|
||||
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
|
||||
// List directory contents
|
||||
dirEntries, err := os.ReadDir(directory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errs := make([]error, 0, len(dirEntries)/2)
|
||||
// Import keys from files, based on extension
|
||||
for _, entry := range dirEntries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
kID, _, _ := strings.Cut(entry.Name(), ".")
|
||||
if kID == "" {
|
||||
continue
|
||||
}
|
||||
pExt := path.Ext(entry.Name())
|
||||
if pExt == "."+keyPrvExt {
|
||||
// Load rsa private key with the file name as the kID (Up to the first .)
|
||||
key, err2 := rsaprivate.Read(path.Join(directory, entry.Name()))
|
||||
if err2 == nil {
|
||||
ks.store[kID] = key
|
||||
ks.storePub[kID] = &key.PublicKey
|
||||
}
|
||||
errs = append(errs, err2)
|
||||
} else if pExt == "."+keyPubExt {
|
||||
// Load rsa public key with the file name as the kID (Up to the first .)
|
||||
key, err2 := rsapublic.Read(path.Join(directory, entry.Name()))
|
||||
if err2 == nil {
|
||||
_, exs := ks.store[kID]
|
||||
if !exs {
|
||||
ks.store[kID] = nil
|
||||
}
|
||||
ks.storePub[kID] = key
|
||||
}
|
||||
errs = append(errs, err2)
|
||||
}
|
||||
}
|
||||
return ks, errors.Join(errs...)
|
||||
}
|
||||
|
||||
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
|
||||
// extensions for public and private keys
|
||||
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
|
||||
if ks == nil {
|
||||
return errors.New("ks is nil")
|
||||
}
|
||||
|
||||
// Create directory
|
||||
err := os.MkdirAll(directory, 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errs := make([]error, 0, len(ks.ListKeys())/2)
|
||||
// Export all keys
|
||||
for _, kID := range ks.ListKeys() {
|
||||
kPrv := ks.GetKey(kID)
|
||||
if kPrv != nil {
|
||||
err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv)
|
||||
errs = append(errs, err2)
|
||||
}
|
||||
kPub := ks.GetKeyPublic(kID)
|
||||
if kPub != nil {
|
||||
err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub)
|
||||
errs = append(errs, err2)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
|
||||
if prvKey == nil {
|
||||
return
|
||||
}
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
d.store[kID] = prvKey
|
||||
d.storePub[kID] = &prvKey.PublicKey
|
||||
return
|
||||
}
|
||||
|
||||
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
|
||||
if pubKey == nil {
|
||||
return
|
||||
}
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
_, exs := d.store[kID]
|
||||
if !exs {
|
||||
d.store[kID] = nil
|
||||
}
|
||||
d.storePub[kID] = pubKey
|
||||
return
|
||||
}
|
||||
|
||||
// RemoveKey removes a specified kID from the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
delete(d.store, kID)
|
||||
delete(d.storePub, kID)
|
||||
return
|
||||
}
|
||||
|
||||
// ListKeys lists the kIDs of all the keys in the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) ListKeys() []string {
|
||||
d.rwLocker.RLock()
|
||||
defer d.rwLocker.RUnlock()
|
||||
lKeys := make([]string, len(d.store))
|
||||
i := 0
|
||||
for k := range d.store {
|
||||
lKeys[i] = k
|
||||
i++
|
||||
}
|
||||
return lKeys
|
||||
}
|
||||
|
||||
// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found.
|
||||
func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey {
|
||||
d.rwLocker.RLock()
|
||||
defer d.rwLocker.RUnlock()
|
||||
kPrv, ok := d.store[kID]
|
||||
if ok {
|
||||
return kPrv
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found.
|
||||
func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey {
|
||||
d.rwLocker.RLock()
|
||||
defer d.rwLocker.RUnlock()
|
||||
kPub, ok := d.storePub[kID]
|
||||
if ok {
|
||||
return kPub
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearKeys removes all the stored keys in the KeyStore.
|
||||
func (d *defaultMJwtKeyStore) ClearKeys() {
|
||||
d.rwLocker.Lock()
|
||||
defer d.rwLocker.Unlock()
|
||||
clear(d.store)
|
||||
clear(d.storePub)
|
||||
}
|
@ -1,152 +0,0 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const kst_prvExt = "prv"
|
||||
const kst_pubExt = "pub"
|
||||
|
||||
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
|
||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
||||
assert.NoError(t, err)
|
||||
|
||||
if genKeys {
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return tempDir, func(t *testing.T) {
|
||||
err := os.RemoveAll(tempDir)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
|
||||
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
|
||||
const extraKID1 = "key4"
|
||||
const extraKID2 = "key5"
|
||||
|
||||
t.Run("TestSetKey", func(t *testing.T) {
|
||||
kStore.SetKey(extraKID1, key4)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
||||
})
|
||||
|
||||
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
||||
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
|
||||
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
||||
})
|
||||
|
||||
t.Run("TestGetKey", func(t *testing.T) {
|
||||
oKey := kStore.GetKey(extraKID1)
|
||||
assert.Same(t, key4, oKey)
|
||||
pKey := kStore.GetKey(extraKID2)
|
||||
assert.Nil(t, pKey)
|
||||
aKey := kStore.GetKey("key1")
|
||||
assert.NotNil(t, aKey)
|
||||
bKey := kStore.GetKey("key2")
|
||||
assert.NotNil(t, bKey)
|
||||
cKey := kStore.GetKey("key3")
|
||||
assert.Nil(t, cKey)
|
||||
})
|
||||
|
||||
t.Run("TestGetKeyPublic", func(t *testing.T) {
|
||||
oKey := kStore.GetKeyPublic(extraKID1)
|
||||
assert.Same(t, &key4.PublicKey, oKey)
|
||||
pKey := kStore.GetKeyPublic(extraKID2)
|
||||
assert.Same(t, &key5.PublicKey, pKey)
|
||||
aKey := kStore.GetKeyPublic("key1")
|
||||
assert.NotNil(t, aKey)
|
||||
bKey := kStore.GetKeyPublic("key2")
|
||||
assert.NotNil(t, bKey)
|
||||
cKey := kStore.GetKeyPublic("key3")
|
||||
assert.NotNil(t, cKey)
|
||||
})
|
||||
|
||||
t.Run("TestRemoveKey", func(t *testing.T) {
|
||||
kStore.RemoveKey(extraKID1)
|
||||
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
||||
oKey1 := kStore.GetKey(extraKID1)
|
||||
assert.Nil(t, oKey1)
|
||||
oKey2 := kStore.GetKeyPublic(extraKID1)
|
||||
assert.Nil(t, oKey2)
|
||||
})
|
||||
|
||||
t.Run("TestClearKeys", func(t *testing.T) {
|
||||
kStore.ClearKeys()
|
||||
assert.Empty(t, kStore.ListKeys())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, cleaner := setupTestDirKeyStore(t, true)
|
||||
defer cleaner(t)
|
||||
|
||||
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, kStore.ListKeys(), 3)
|
||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, kStore.ListKeys(), k)
|
||||
}
|
||||
|
||||
commonSubTestsKeyStore(t, kStore)
|
||||
}
|
||||
|
||||
func TestExportKeyStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir, cleaner := setupTestDirKeyStore(t, true)
|
||||
defer cleaner(t)
|
||||
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
|
||||
defer cleaner2(t)
|
||||
|
||||
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
|
||||
assert.NoError(t, err)
|
||||
|
||||
const prvExt2 = "v"
|
||||
const pubExt2 = "b"
|
||||
|
||||
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
kIDsToFind := kStore.ListKeys()
|
||||
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
|
||||
for _, k := range kIDsToFind {
|
||||
assert.Contains(t, kStore2.ListKeys(), k)
|
||||
}
|
||||
|
||||
commonSubTestsKeyStore(t, kStore2)
|
||||
}
|
233
keystore.go
Normal file
233
keystore.go
Normal file
@ -0,0 +1,233 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrMissingPrivateKey = errors.New("missing private key")
|
||||
var ErrMissingPublicKey = errors.New("missing public key")
|
||||
var ErrMissingKeyPair = errors.New("missing key pair")
|
||||
|
||||
const PrivateStr = ".private"
|
||||
const PublicStr = ".public"
|
||||
|
||||
const PemExt = ".pem"
|
||||
const PrivatePemExt = PrivateStr + PemExt
|
||||
const PublicPemExt = PublicStr + PemExt
|
||||
|
||||
type KeyStore struct {
|
||||
mu *sync.RWMutex
|
||||
store map[string]*keyPair
|
||||
dir afero.Fs
|
||||
}
|
||||
|
||||
func NewKeyStore() *KeyStore {
|
||||
return &KeyStore{
|
||||
mu: new(sync.RWMutex),
|
||||
store: make(map[string]*keyPair),
|
||||
}
|
||||
}
|
||||
|
||||
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
|
||||
keyStore := NewKeyStore()
|
||||
keyStore.dir = dir
|
||||
return keyStore
|
||||
}
|
||||
|
||||
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
|
||||
keyStore := NewKeyStoreWithDir(dir)
|
||||
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
|
||||
// maybe this is "name.private.pem"
|
||||
name := filepath.Base(path)
|
||||
ext := filepath.Ext(name)
|
||||
if ext != PemExt {
|
||||
return nil
|
||||
}
|
||||
|
||||
name = strings.TrimSuffix(name, ext)
|
||||
ext = filepath.Ext(name)
|
||||
name = strings.TrimSuffix(name, ext)
|
||||
switch ext {
|
||||
case PrivateStr:
|
||||
open, err := dir.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decode, err := rsaprivate.Decode(open)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyStore.LoadPrivateKey(name, decode)
|
||||
return nil
|
||||
case PublicStr:
|
||||
open, err := dir.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decode, err := rsapublic.Decode(open)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyStore.LoadPublicKey(name, decode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// still invalid
|
||||
return nil
|
||||
})
|
||||
return keyStore, err
|
||||
}
|
||||
|
||||
type keyPair struct {
|
||||
private *rsa.PrivateKey
|
||||
public *rsa.PublicKey
|
||||
}
|
||||
|
||||
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
|
||||
k.mu.Lock()
|
||||
if k.store[kid] == nil {
|
||||
k.store[kid] = &keyPair{}
|
||||
}
|
||||
k.store[kid].private = key
|
||||
k.store[kid].public = &key.PublicKey
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
|
||||
k.mu.Lock()
|
||||
if k.store[kid] == nil {
|
||||
k.store[kid] = &keyPair{}
|
||||
}
|
||||
k.store[kid].public = key
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) RemoveKey(kid string) {
|
||||
k.mu.Lock()
|
||||
delete(k.store, kid)
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) ListKeys() []string {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
keys := make([]string, 0, len(k.store))
|
||||
for k, _ := range k.store {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if !k.internalHasPrivateKey(kid) {
|
||||
return nil, ErrMissingPrivateKey
|
||||
}
|
||||
return k.store[kid].private, nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if !k.internalHasPublicKey(kid) {
|
||||
return nil, ErrMissingPublicKey
|
||||
}
|
||||
return k.store[kid].public, nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) ClearKeys() {
|
||||
k.mu.Lock()
|
||||
clear(k.store)
|
||||
k.mu.Unlock()
|
||||
}
|
||||
|
||||
func (k *KeyStore) HasPrivateKey(kid string) bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.internalHasPrivateKey(kid)
|
||||
}
|
||||
|
||||
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
|
||||
v := k.store[kid]
|
||||
return v != nil && v.private != nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) HasPublicKey(kid string) bool {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
return k.internalHasPublicKey(kid)
|
||||
}
|
||||
|
||||
func (k *KeyStore) internalHasPublicKey(kid string) bool {
|
||||
v := k.store[kid]
|
||||
return v != nil && v.public != nil
|
||||
}
|
||||
|
||||
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, ErrMissingPublicKey
|
||||
}
|
||||
return k.GetPublicKey(kid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withClaims, claims.Valid()
|
||||
}
|
||||
|
||||
func (k *KeyStore) SaveSingleKey(kid string) error {
|
||||
if k.dir == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
k.mu.RLock()
|
||||
pair := k.store[kid]
|
||||
k.mu.RUnlock()
|
||||
if pair == nil {
|
||||
return ErrMissingKeyPair
|
||||
}
|
||||
|
||||
var errs []error
|
||||
if pair.private != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||
}
|
||||
if pair.public != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (k *KeyStore) SaveKeys() error {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
workers := new(errgroup.Group)
|
||||
workers.SetLimit(runtime.NumCPU())
|
||||
for kid, pair := range k.store {
|
||||
workers.Go(func() error {
|
||||
var errs []error
|
||||
if pair.private != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||
}
|
||||
if pair.public != nil {
|
||||
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
})
|
||||
}
|
||||
return workers.Wait()
|
||||
}
|
175
keystore_test.go
Normal file
175
keystore_test.go
Normal file
@ -0,0 +1,175 @@
|
||||
package mjwt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"github.com/1f349/rsa-helper/rsaprivate"
|
||||
"github.com/1f349/rsa-helper/rsapublic"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const kst_prvExt = "prv"
|
||||
const kst_pubExt = "pub"
|
||||
|
||||
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
|
||||
tempDir := afero.NewMemMapFs()
|
||||
|
||||
if genKeys {
|
||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
assert.NoError(t, err)
|
||||
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return tempDir
|
||||
}
|
||||
|
||||
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
|
||||
< |