mirror of
https://github.com/1f349/mjwt.git
synced 2024-12-22 07:24:05 +00:00
Rewrite mjwt library to better support keystores
This commit is contained in:
parent
5d1bd6f8fd
commit
cd2d80cb09
@ -2,14 +2,13 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/mjwt/claims"
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccessTokenClaims contains the JWT claims for an access token
|
// AccessTokenClaims contains the JWT claims for an access token
|
||||||
type AccessTokenClaims struct {
|
type AccessTokenClaims struct {
|
||||||
Perms *claims.PermStorage `json:"per"`
|
Perms *PermStorage `json:"per"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AccessTokenClaims) Valid() error { return nil }
|
func (a AccessTokenClaims) Valid() error { return nil }
|
||||||
@ -17,21 +16,11 @@ func (a AccessTokenClaims) Valid() error { return nil }
|
|||||||
func (a AccessTokenClaims) Type() string { return "access-token" }
|
func (a AccessTokenClaims) Type() string { return "access-token" }
|
||||||
|
|
||||||
// CreateAccessToken creates an access token with the default 15 minute duration
|
// CreateAccessToken creates an access token with the default 15 minute duration
|
||||||
func CreateAccessToken(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
func CreateAccessToken(p *mjwt.Issuer, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||||
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
|
return CreateAccessTokenWithDuration(p, time.Minute*15, sub, id, aud, perms)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
// CreateAccessTokenWithDuration creates an access token with a custom duration
|
||||||
func CreateAccessTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage) (string, error) {
|
func CreateAccessTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *PermStorage) (string, error) {
|
||||||
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
|
return p.GenerateJwt(sub, id, aud, dur, &AccessTokenClaims{Perms: perms})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccessTokenWithKID creates an access token with the default 15 minute duration and the specified kID
|
|
||||||
func CreateAccessTokenWithKID(p mjwt.Signer, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
|
|
||||||
return CreateAccessTokenWithDurationAndKID(p, time.Minute*15, sub, id, aud, perms, kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAccessTokenWithDurationAndKID creates an access token with a custom duration and the specified kID
|
|
||||||
func CreateAccessTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id string, aud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, error) {
|
|
||||||
return p.GenerateJwtWithKID(sub, id, aud, dur, &AccessTokenClaims{Perms: perms}, kID)
|
|
||||||
}
|
|
||||||
|
@ -1,55 +1,26 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/mjwt/claims"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateAccessToken(t *testing.T) {
|
func TestCreateAccessToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
ps := claims.NewPermStorage()
|
ps := NewPermStorage()
|
||||||
ps.Set("mjwt:test")
|
ps.Set("mjwt:test")
|
||||||
ps.Set("mjwt:test2")
|
ps.Set("mjwt:test2")
|
||||||
|
|
||||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
kStore := mjwt.NewKeyStore()
|
||||||
|
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
|
accessToken, err := CreateAccessToken(s, "1", "test", nil, ps)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "1", b.Subject)
|
|
||||||
assert.Equal(t, "test", b.ID)
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
|
||||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateAccessTokenInvalid(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
kStore := mjwt.NewMJwtKeyStore()
|
|
||||||
kStore.SetKey("test", key)
|
|
||||||
|
|
||||||
ps := claims.NewPermStorage()
|
|
||||||
ps.Set("mjwt:test")
|
|
||||||
ps.Set("mjwt:test2")
|
|
||||||
|
|
||||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
|
||||||
|
|
||||||
accessToken, err := CreateAccessTokenWithKID(s, "1", "test", nil, ps, "test")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b.Subject)
|
assert.Equal(t, "1", b.Subject)
|
||||||
assert.Equal(t, "test", b.ID)
|
assert.Equal(t, "test", b.ID)
|
||||||
|
25
auth/pair.go
25
auth/pair.go
@ -2,20 +2,19 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/mjwt/claims"
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateTokenPair creates an access and refresh token pair using the default
|
// CreateTokenPair creates an access and refresh token pair using the default
|
||||||
// 15 minute and 7 day durations respectively
|
// 15 minute and 7 day durations respectively
|
||||||
func CreateTokenPair(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
func CreateTokenPair(p *mjwt.Issuer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||||
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
|
return CreateTokenPairWithDuration(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
// CreateTokenPairWithDuration creates an access and refresh token pair using
|
||||||
// custom durations for the access and refresh tokens
|
// custom durations for the access and refresh tokens
|
||||||
func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage) (string, string, error) {
|
func CreateTokenPairWithDuration(p *mjwt.Issuer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *PermStorage) (string, string, error) {
|
||||||
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
|
accessToken, err := CreateAccessTokenWithDuration(p, accessDur, sub, id, aud, perms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
@ -26,23 +25,3 @@ func CreateTokenPairWithDuration(p mjwt.Signer, accessDur, refreshDur time.Durat
|
|||||||
}
|
}
|
||||||
return accessToken, refreshToken, nil
|
return accessToken, refreshToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenPairWithKID creates an access and refresh token pair using the default
|
|
||||||
// 15 minute and 7 day durations respectively using the specified kID
|
|
||||||
func CreateTokenPairWithKID(p mjwt.Signer, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
|
|
||||||
return CreateTokenPairWithDurationAndKID(p, time.Minute*15, time.Hour*24*7, sub, id, rId, aud, rAud, perms, kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTokenPairWithDurationAndKID creates an access and refresh token pair using
|
|
||||||
// custom durations for the access and refresh tokens
|
|
||||||
func CreateTokenPairWithDurationAndKID(p mjwt.Signer, accessDur, refreshDur time.Duration, sub, id, rId string, aud, rAud jwt.ClaimStrings, perms *claims.PermStorage, kID string) (string, string, error) {
|
|
||||||
accessToken, err := CreateAccessTokenWithDurationAndKID(p, accessDur, sub, id, aud, perms, kID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
refreshToken, err := CreateRefreshTokenWithDurationAndKID(p, refreshDur, sub, rId, id, rAud, kID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return accessToken, refreshToken, nil
|
|
||||||
}
|
|
||||||
|
@ -1,29 +1,26 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/mjwt/claims"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateTokenPair(t *testing.T) {
|
func TestCreateTokenPair(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
ps := claims.NewPermStorage()
|
ps := NewPermStorage()
|
||||||
ps.Set("mjwt:test")
|
ps.Set("mjwt:test")
|
||||||
ps.Set("mjwt:test2")
|
ps.Set("mjwt:test2")
|
||||||
|
|
||||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
kStore := mjwt.NewKeyStore()
|
||||||
|
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
|
accessToken, refreshToken, err := CreateTokenPair(s, "1", "test", "test2", nil, nil, ps)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](kStore, accessToken)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b.Subject)
|
assert.Equal(t, "1", b.Subject)
|
||||||
assert.Equal(t, "test", b.ID)
|
assert.Equal(t, "test", b.ID)
|
||||||
@ -31,38 +28,7 @@ func TestCreateTokenPair(t *testing.T) {
|
|||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
||||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
||||||
|
|
||||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "1", b2.Subject)
|
|
||||||
assert.Equal(t, "test2", b2.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateTokenPairWithKID(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
kStore := mjwt.NewMJwtKeyStore()
|
|
||||||
kStore.SetKey("test", key)
|
|
||||||
|
|
||||||
ps := claims.NewPermStorage()
|
|
||||||
ps.Set("mjwt:test")
|
|
||||||
ps.Set("mjwt:test2")
|
|
||||||
|
|
||||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
|
||||||
|
|
||||||
accessToken, refreshToken, err := CreateTokenPairWithKID(s, "1", "test", "test2", nil, nil, ps, "test")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[AccessTokenClaims](s, accessToken)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "1", b.Subject)
|
|
||||||
assert.Equal(t, "test", b.ID)
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test"))
|
|
||||||
assert.True(t, b.Claims.Perms.Has("mjwt:test2"))
|
|
||||||
assert.False(t, b.Claims.Perms.Has("mjwt:test3"))
|
|
||||||
|
|
||||||
_, b2, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b2.Subject)
|
assert.Equal(t, "1", b2.Subject)
|
||||||
assert.Equal(t, "test2", b2.ID)
|
assert.Equal(t, "test2", b2.ID)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package claims
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
@ -1,4 +1,4 @@
|
|||||||
package claims
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
@ -16,21 +16,11 @@ func (r RefreshTokenClaims) Valid() error { return nil }
|
|||||||
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
func (r RefreshTokenClaims) Type() string { return "refresh-token" }
|
||||||
|
|
||||||
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
// CreateRefreshToken creates a refresh token with the default 7 day duration
|
||||||
func CreateRefreshToken(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
func CreateRefreshToken(p *mjwt.Issuer, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||||
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
|
return CreateRefreshTokenWithDuration(p, time.Hour*24*7, sub, id, ati, aud)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
// CreateRefreshTokenWithDuration creates a refresh token with a custom duration
|
||||||
func CreateRefreshTokenWithDuration(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
func CreateRefreshTokenWithDuration(p *mjwt.Issuer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings) (string, error) {
|
||||||
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
|
return p.GenerateJwt(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateRefreshTokenWithKID creates a refresh token with the default 7 day duration and the specified kID
|
|
||||||
func CreateRefreshTokenWithKID(p mjwt.Signer, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
|
|
||||||
return CreateRefreshTokenWithDurationAndKID(p, time.Hour*24*7, sub, id, ati, aud, kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateRefreshTokenWithDurationAndKID creates a refresh token with a custom duration and the specified kID
|
|
||||||
func CreateRefreshTokenWithDurationAndKID(p mjwt.Signer, dur time.Duration, sub, id, ati string, aud jwt.ClaimStrings, kID string) (string, error) {
|
|
||||||
return p.GenerateJwtWithKID(sub, id, aud, dur, RefreshTokenClaims{AccessTokenId: ati}, kID)
|
|
||||||
}
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
@ -10,35 +8,15 @@ import (
|
|||||||
|
|
||||||
func TestCreateRefreshToken(t *testing.T) {
|
func TestCreateRefreshToken(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
s := mjwt.NewMJwtSigner("mjwt.test", key)
|
kStore := mjwt.NewKeyStore()
|
||||||
|
s, err := mjwt.NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
|
refreshToken, err := CreateRefreshToken(s, "1", "test", "test2", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](kStore, refreshToken)
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "1", b.Subject)
|
|
||||||
assert.Equal(t, "test", b.ID)
|
|
||||||
assert.Equal(t, "test2", b.Claims.AccessTokenId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateRefreshTokenWithKID(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
kStore := mjwt.NewMJwtKeyStore()
|
|
||||||
kStore.SetKey("test", key)
|
|
||||||
|
|
||||||
s := mjwt.NewMJwtSignerWithKeyStore("mjwt.test", nil, kStore)
|
|
||||||
|
|
||||||
refreshToken, err := CreateRefreshTokenWithKID(s, "1", "test", "test2", nil, "test")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, b, err := mjwt.ExtractClaims[RefreshTokenClaims](s, refreshToken)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "1", b.Subject)
|
assert.Equal(t, "1", b.Subject)
|
||||||
assert.Equal(t, "test", b.ID)
|
assert.Equal(t, "test", b.ID)
|
||||||
|
@ -10,11 +10,11 @@ import (
|
|||||||
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
var ErrClaimTypeMismatch = errors.New("claim type mismatch")
|
||||||
|
|
||||||
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
// wrapClaims creates a BaseTypeClaims wrapper for a generic claims struct
|
||||||
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
func wrapClaims[T Claims](sub, id, issuer string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return (&BaseTypeClaims[T]{
|
return (&BaseTypeClaims[T]{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
Issuer: p.Issuer(),
|
Issuer: issuer,
|
||||||
Subject: sub,
|
Subject: sub,
|
||||||
Audience: aud,
|
Audience: aud,
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
ExpiresAt: jwt.NewNumericDate(now.Add(dur)),
|
||||||
@ -28,12 +28,12 @@ func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur ti
|
|||||||
|
|
||||||
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
// ExtractClaims uses a Verifier to validate the MJWT token and returns the parsed
|
||||||
// token and BaseTypeClaims
|
// token and BaseTypeClaims
|
||||||
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
func ExtractClaims[T Claims](ks *KeyStore, token string) (*jwt.Token, BaseTypeClaims[T], error) {
|
||||||
b := BaseTypeClaims[T]{
|
b := BaseTypeClaims[T]{
|
||||||
RegisteredClaims: jwt.RegisteredClaims{},
|
RegisteredClaims: jwt.RegisteredClaims{},
|
||||||
Claims: *new(T),
|
Claims: *new(T),
|
||||||
}
|
}
|
||||||
tok, err := p.VerifyJwt(token, &b)
|
tok, err := ks.VerifyJwt(token, &b)
|
||||||
return tok, b, err
|
return tok, b, err
|
||||||
}
|
}
|
||||||
|
|
100
claims_test.go
Normal file
100
claims_test.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testClaims struct{ TestValue string }
|
||||||
|
|
||||||
|
func (t testClaims) Valid() error {
|
||||||
|
if t.TestValue != "hello" && t.TestValue != "world" {
|
||||||
|
return fmt.Errorf("TestValue should be hello")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testClaims) Type() string { return "testClaims" }
|
||||||
|
|
||||||
|
type testClaims2 struct{ TestValue2 string }
|
||||||
|
|
||||||
|
func (t testClaims2) Valid() error {
|
||||||
|
if t.TestValue2 != "world" {
|
||||||
|
return fmt.Errorf("TestValue2 should be world")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testClaims2) Type() string { return "testClaims2" }
|
||||||
|
|
||||||
|
func TestExtractClaims(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
kStore := NewKeyStore()
|
||||||
|
|
||||||
|
t.Run("TestNoKID", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
a, _, err := ExtractClaims[testClaims](kStore, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
kid, _ := a.Header["kid"].(string)
|
||||||
|
assert.Equal(t, "key1", kid)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestKID", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
s2, err := NewIssuerWithKeyStore("mjwt.test", "key3", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
token2, err := s2.GenerateJwt("2", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
k1, _, err := ExtractClaims[testClaims](kStore, token1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
k2, _, err := ExtractClaims[testClaims](kStore, token2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, k1.Header["kid"], k2.Header["kid"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractClaimsFail(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
kStore := NewKeyStore()
|
||||||
|
|
||||||
|
t.Run("TestInvalidClaims", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s, err := NewIssuerWithKeyStore("mjwt.test", "key1", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, err = ExtractClaims[testClaims2](kStore, token)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestKIDNonExist", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s, err := NewIssuerWithKeyStore("mjwt.test", "key2", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
kStore.RemoveKey("key2")
|
||||||
|
assert.NotContains(t, kStore.ListKeys(), "key2")
|
||||||
|
|
||||||
|
_, _, err = ExtractClaims[testClaims](kStore, token)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||||
|
})
|
||||||
|
}
|
@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/mjwt/auth"
|
"github.com/1f349/mjwt/auth"
|
||||||
"github.com/1f349/mjwt/claims"
|
|
||||||
"github.com/1f349/rsa-helper/rsaprivate"
|
"github.com/1f349/rsa-helper/rsaprivate"
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/subcommands"
|
"github.com/google/subcommands"
|
||||||
@ -35,7 +34,7 @@ func (s *accessCmd) SetFlags(f *flag.FlagSet) {
|
|||||||
f.StringVar(&s.id, "id", "", "MJWT ID")
|
f.StringVar(&s.id, "id", "", "MJWT ID")
|
||||||
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
|
f.StringVar(&s.audience, "aud", "", "Comma separated audience items for the MJWT")
|
||||||
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
|
f.StringVar(&s.duration, "dur", "15m", "Duration for the MJWT (default: 15m)")
|
||||||
f.StringVar(&s.kID, "kid", "\x00", "The Key ID of the signing key")
|
f.StringVar(&s.kID, "kid", "", "The Key ID of the signing key")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
|
||||||
@ -51,7 +50,7 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
|||||||
return subcommands.ExitFailure
|
return subcommands.ExitFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
ps := claims.NewPermStorage()
|
ps := auth.NewPermStorage()
|
||||||
for i := 1; i < len(args); i++ {
|
for i := 1; i < len(args); i++ {
|
||||||
ps.Set(args[i])
|
ps.Set(args[i])
|
||||||
}
|
}
|
||||||
@ -67,16 +66,16 @@ func (s *accessCmd) Execute(_ context.Context, f *flag.FlagSet, _ ...interface{}
|
|||||||
}
|
}
|
||||||
|
|
||||||
var token string
|
var token string
|
||||||
if s.kID == "\x00" {
|
|
||||||
signer := mjwt.NewMJwtSigner(s.issuer, key)
|
kStore := mjwt.NewKeyStore()
|
||||||
token, err = signer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
kStore.LoadPrivateKey(s.kID, key)
|
||||||
} else {
|
|
||||||
kStore := mjwt.NewMJwtKeyStore()
|
issuer, err := mjwt.NewIssuerWithKeyStore(s.issuer, s.kID, kStore)
|
||||||
kStore.SetKey(s.kID, key)
|
if err != nil {
|
||||||
signer := mjwt.NewMJwtSignerWithKeyStore(s.issuer, nil, kStore)
|
panic("this should not fail")
|
||||||
token, err = signer.GenerateJwtWithKID(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps}, s.kID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
token, err = issuer.GenerateJwt(s.subject, s.id, aud, dur, auth.AccessTokenClaims{Perms: ps})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
|
_, _ = fmt.Fprintln(os.Stderr, "Error: Failed to generate MJWT token: ", err)
|
||||||
return subcommands.ExitFailure
|
return subcommands.ExitFailure
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package claims
|
package mjwt
|
||||||
|
|
||||||
// EmptyClaims contains no claims
|
// EmptyClaims contains no claims
|
||||||
type EmptyClaims struct{}
|
type EmptyClaims struct{}
|
8
go.mod
8
go.mod
@ -10,11 +10,17 @@ require (
|
|||||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||||
github.com/google/subcommands v1.2.0
|
github.com/google/subcommands v1.2.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/spf13/afero v1.11.0
|
||||||
|
github.com/stretchr/testify v1.9.0
|
||||||
|
golang.org/x/sync v0.7.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/kr/pretty v0.3.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||||
|
golang.org/x/text v0.16.0 // indirect
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
)
|
)
|
||||||
|
25
go.sum
25
go.sum
@ -2,19 +2,38 @@ github.com/1f349/rsa-helper v0.0.2 h1:N/fLQqg5wrjIzG6G4zdwa5Xcv9/jIPutCls9YekZr9
|
|||||||
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
github.com/1f349/rsa-helper v0.0.2/go.mod h1:VUQ++1tYYhYrXeOmVFkQ82BegR24HQEJHl5lHbjg7yg=
|
||||||
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
|
github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA=
|
||||||
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
|
github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||||
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||||
|
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||||
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||||
|
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||||
|
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
@ -1,39 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rsa"
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Signer is used to generate MJWT tokens.
|
|
||||||
// Signer can also be used as a Verifier.
|
|
||||||
type Signer interface {
|
|
||||||
Verifier
|
|
||||||
GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error)
|
|
||||||
SignJwt(claims jwt.Claims) (string, error)
|
|
||||||
GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error)
|
|
||||||
SignJwtWithKID(claims jwt.Claims, kID string) (string, error)
|
|
||||||
Issuer() string
|
|
||||||
PrivateKey() *rsa.PrivateKey
|
|
||||||
PrivateKeyOf(kID string) *rsa.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verifier is used to verify the validity MJWT tokens and extract the claim values.
|
|
||||||
type Verifier interface {
|
|
||||||
VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error)
|
|
||||||
PublicKey() *rsa.PublicKey
|
|
||||||
PublicKeyOf(kID string) *rsa.PublicKey
|
|
||||||
GetKeyStore() KeyStore
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyStore is used for the kid header support in Signer and Verifier.
|
|
||||||
type KeyStore interface {
|
|
||||||
SetKey(kID string, prvKey *rsa.PrivateKey)
|
|
||||||
SetKeyPublic(kID string, pubKey *rsa.PublicKey)
|
|
||||||
RemoveKey(kID string)
|
|
||||||
ListKeys() []string
|
|
||||||
GetKey(kID string) *rsa.PrivateKey
|
|
||||||
GetKeyPublic(kID string) *rsa.PublicKey
|
|
||||||
ClearKeys()
|
|
||||||
}
|
|
49
issuer.go
Normal file
49
issuer.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Issuer struct {
|
||||||
|
issuer string
|
||||||
|
kid string
|
||||||
|
keystore *KeyStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIssuer(name, kid string) (*Issuer, error) {
|
||||||
|
return NewIssuerWithKeyStore(name, kid, NewKeyStore())
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIssuerWithKeyStore(name, kid string, keystore *KeyStore) (*Issuer, error) {
|
||||||
|
i := &Issuer{name, kid, keystore}
|
||||||
|
if i.keystore.HasPrivateKey(kid) {
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i.keystore.LoadPrivateKey(kid, key)
|
||||||
|
return i, i.keystore.SaveSingleKey(kid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Issuer) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
||||||
|
return i.SignJwt(wrapClaims[Claims](sub, id, i.issuer, aud, dur, claims))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Issuer) SignJwt(wrapped jwt.Claims) (string, error) {
|
||||||
|
key, err := i.PrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
||||||
|
token.Header["kid"] = i.kid
|
||||||
|
return token.SignedString(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Issuer) PrivateKey() (*rsa.PrivateKey, error) {
|
||||||
|
return i.keystore.GetPrivateKey(i.kid)
|
||||||
|
}
|
58
issuer_test.go
Normal file
58
issuer_test.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/1f349/rsa-helper/rsaprivate"
|
||||||
|
"github.com/spf13/afero"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewIssuer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
t.Run("generate missing key for issuer", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
kStore := NewKeyStore()
|
||||||
|
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, kStore.HasPrivateKey("test"))
|
||||||
|
assert.True(t, kStore.HasPublicKey("test"))
|
||||||
|
assert.Equal(t, "Test", issuer.issuer)
|
||||||
|
assert.Equal(t, "test", issuer.kid)
|
||||||
|
})
|
||||||
|
t.Run("use existing issuer key", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
kStore := NewKeyStore()
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
kStore.LoadPrivateKey("test", key)
|
||||||
|
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, kStore.HasPrivateKey("test"))
|
||||||
|
assert.True(t, kStore.HasPublicKey("test"))
|
||||||
|
assert.Equal(t, "Test", issuer.issuer)
|
||||||
|
assert.Equal(t, "test", issuer.kid)
|
||||||
|
privateKey, err := issuer.PrivateKey()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, key.Equal(privateKey))
|
||||||
|
})
|
||||||
|
t.Run("generate missing key in filesystem", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dir := afero.NewMemMapFs()
|
||||||
|
kStore := NewKeyStoreWithDir(dir)
|
||||||
|
issuer, err := NewIssuerWithKeyStore("Test", "test", kStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, kStore.HasPrivateKey("test"))
|
||||||
|
assert.True(t, kStore.HasPublicKey("test"))
|
||||||
|
assert.Equal(t, "Test", issuer.issuer)
|
||||||
|
assert.Equal(t, "test", issuer.kid)
|
||||||
|
privKeyFile, err := dir.Open("test.private.pem")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
privKey, err := rsaprivate.Decode(privKeyFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
key, err := issuer.PrivateKey()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, key.Equal(privKey))
|
||||||
|
})
|
||||||
|
}
|
185
key_store.go
185
key_store.go
@ -1,185 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"github.com/1f349/rsa-helper/rsaprivate"
|
|
||||||
"github.com/1f349/rsa-helper/rsapublic"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultMJwtKeyStore implements KeyStore and stores kIDs against just rsa.PublicKey
|
|
||||||
// or with rsa.PrivateKey instances as well.
|
|
||||||
type defaultMJwtKeyStore struct {
|
|
||||||
rwLocker *sync.RWMutex
|
|
||||||
store map[string]*rsa.PrivateKey
|
|
||||||
storePub map[string]*rsa.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ KeyStore = &defaultMJwtKeyStore{}
|
|
||||||
|
|
||||||
// NewMJwtKeyStore creates a new defaultMJwtKeyStore.
|
|
||||||
func NewMJwtKeyStore() KeyStore {
|
|
||||||
return &defaultMJwtKeyStore{
|
|
||||||
rwLocker: new(sync.RWMutex),
|
|
||||||
store: make(map[string]*rsa.PrivateKey),
|
|
||||||
storePub: make(map[string]*rsa.PublicKey),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtKeyStoreFromDirectory loads keys from a directory with the specified extensions to denote public and private
|
|
||||||
// rsa keys; the kID is the filename of the key up to the first .
|
|
||||||
func NewMJwtKeyStoreFromDirectory(directory, keyPrvExt, keyPubExt string) (KeyStore, error) {
|
|
||||||
// Create empty KeyStore
|
|
||||||
ks := NewMJwtKeyStore().(*defaultMJwtKeyStore)
|
|
||||||
// List directory contents
|
|
||||||
dirEntries, err := os.ReadDir(directory)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
errs := make([]error, 0, len(dirEntries)/2)
|
|
||||||
// Import keys from files, based on extension
|
|
||||||
for _, entry := range dirEntries {
|
|
||||||
if entry.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
kID, _, _ := strings.Cut(entry.Name(), ".")
|
|
||||||
if kID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pExt := path.Ext(entry.Name())
|
|
||||||
if pExt == "."+keyPrvExt {
|
|
||||||
// Load rsa private key with the file name as the kID (Up to the first .)
|
|
||||||
key, err2 := rsaprivate.Read(path.Join(directory, entry.Name()))
|
|
||||||
if err2 == nil {
|
|
||||||
ks.store[kID] = key
|
|
||||||
ks.storePub[kID] = &key.PublicKey
|
|
||||||
}
|
|
||||||
errs = append(errs, err2)
|
|
||||||
} else if pExt == "."+keyPubExt {
|
|
||||||
// Load rsa public key with the file name as the kID (Up to the first .)
|
|
||||||
key, err2 := rsapublic.Read(path.Join(directory, entry.Name()))
|
|
||||||
if err2 == nil {
|
|
||||||
_, exs := ks.store[kID]
|
|
||||||
if !exs {
|
|
||||||
ks.store[kID] = nil
|
|
||||||
}
|
|
||||||
ks.storePub[kID] = key
|
|
||||||
}
|
|
||||||
errs = append(errs, err2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ks, errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExportKeyStore saves all the keys stored in the specified KeyStore into a directory with the specified
|
|
||||||
// extensions for public and private keys
|
|
||||||
func ExportKeyStore(ks KeyStore, directory, keyPrvExt, keyPubExt string) error {
|
|
||||||
if ks == nil {
|
|
||||||
return errors.New("ks is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create directory
|
|
||||||
err := os.MkdirAll(directory, 0700)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
errs := make([]error, 0, len(ks.ListKeys())/2)
|
|
||||||
// Export all keys
|
|
||||||
for _, kID := range ks.ListKeys() {
|
|
||||||
kPrv := ks.GetKey(kID)
|
|
||||||
if kPrv != nil {
|
|
||||||
err2 := rsaprivate.Write(path.Join(directory, kID+"."+keyPrvExt), kPrv)
|
|
||||||
errs = append(errs, err2)
|
|
||||||
}
|
|
||||||
kPub := ks.GetKeyPublic(kID)
|
|
||||||
if kPub != nil {
|
|
||||||
err2 := rsapublic.Write(path.Join(directory, kID+"."+keyPubExt), kPub)
|
|
||||||
errs = append(errs, err2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKey adds a new rsa.PrivateKey with the specified kID to the KeyStore.
|
|
||||||
func (d *defaultMJwtKeyStore) SetKey(kID string, prvKey *rsa.PrivateKey) {
|
|
||||||
if prvKey == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.rwLocker.Lock()
|
|
||||||
defer d.rwLocker.Unlock()
|
|
||||||
d.store[kID] = prvKey
|
|
||||||
d.storePub[kID] = &prvKey.PublicKey
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKeyPublic adds a new rsa.PublicKey with the specified kID to the KeyStore.
|
|
||||||
func (d *defaultMJwtKeyStore) SetKeyPublic(kID string, pubKey *rsa.PublicKey) {
|
|
||||||
if pubKey == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.rwLocker.Lock()
|
|
||||||
defer d.rwLocker.Unlock()
|
|
||||||
_, exs := d.store[kID]
|
|
||||||
if !exs {
|
|
||||||
d.store[kID] = nil
|
|
||||||
}
|
|
||||||
d.storePub[kID] = pubKey
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveKey removes a specified kID from the KeyStore.
|
|
||||||
func (d *defaultMJwtKeyStore) RemoveKey(kID string) {
|
|
||||||
d.rwLocker.Lock()
|
|
||||||
defer d.rwLocker.Unlock()
|
|
||||||
delete(d.store, kID)
|
|
||||||
delete(d.storePub, kID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListKeys lists the kIDs of all the keys in the KeyStore.
|
|
||||||
func (d *defaultMJwtKeyStore) ListKeys() []string {
|
|
||||||
d.rwLocker.RLock()
|
|
||||||
defer d.rwLocker.RUnlock()
|
|
||||||
lKeys := make([]string, len(d.store))
|
|
||||||
i := 0
|
|
||||||
for k := range d.store {
|
|
||||||
lKeys[i] = k
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return lKeys
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKey gets the rsa.PrivateKey given the kID in the KeyStore or null if not found.
|
|
||||||
func (d *defaultMJwtKeyStore) GetKey(kID string) *rsa.PrivateKey {
|
|
||||||
d.rwLocker.RLock()
|
|
||||||
defer d.rwLocker.RUnlock()
|
|
||||||
kPrv, ok := d.store[kID]
|
|
||||||
if ok {
|
|
||||||
return kPrv
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKeyPublic gets the rsa.PublicKey given the kID in the KeyStore or null if not found.
|
|
||||||
func (d *defaultMJwtKeyStore) GetKeyPublic(kID string) *rsa.PublicKey {
|
|
||||||
d.rwLocker.RLock()
|
|
||||||
defer d.rwLocker.RUnlock()
|
|
||||||
kPub, ok := d.storePub[kID]
|
|
||||||
if ok {
|
|
||||||
return kPub
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearKeys removes all the stored keys in the KeyStore.
|
|
||||||
func (d *defaultMJwtKeyStore) ClearKeys() {
|
|
||||||
d.rwLocker.Lock()
|
|
||||||
defer d.rwLocker.Unlock()
|
|
||||||
clear(d.store)
|
|
||||||
clear(d.storePub)
|
|
||||||
}
|
|
@ -1,152 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"github.com/1f349/rsa-helper/rsaprivate"
|
|
||||||
"github.com/1f349/rsa-helper/rsapublic"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
const kst_prvExt = "prv"
|
|
||||||
const kst_pubExt = "pub"
|
|
||||||
|
|
||||||
func setupTestDirKeyStore(t *testing.T, genKeys bool) (string, func(t *testing.T)) {
|
|
||||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
if genKeys {
|
|
||||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+kst_prvExt), key1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+kst_prvExt), key2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+kst_pubExt), &key2.PublicKey)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+kst_pubExt), &key3.PublicKey)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tempDir, func(t *testing.T) {
|
|
||||||
err := os.RemoveAll(tempDir)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func commonSubTestsKeyStore(t *testing.T, kStore KeyStore) {
|
|
||||||
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
const extraKID1 = "key4"
|
|
||||||
const extraKID2 = "key5"
|
|
||||||
|
|
||||||
t.Run("TestSetKey", func(t *testing.T) {
|
|
||||||
kStore.SetKey(extraKID1, key4)
|
|
||||||
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
|
||||||
kStore.SetKeyPublic(extraKID2, &key5.PublicKey)
|
|
||||||
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestGetKey", func(t *testing.T) {
|
|
||||||
oKey := kStore.GetKey(extraKID1)
|
|
||||||
assert.Same(t, key4, oKey)
|
|
||||||
pKey := kStore.GetKey(extraKID2)
|
|
||||||
assert.Nil(t, pKey)
|
|
||||||
aKey := kStore.GetKey("key1")
|
|
||||||
assert.NotNil(t, aKey)
|
|
||||||
bKey := kStore.GetKey("key2")
|
|
||||||
assert.NotNil(t, bKey)
|
|
||||||
cKey := kStore.GetKey("key3")
|
|
||||||
assert.Nil(t, cKey)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestGetKeyPublic", func(t *testing.T) {
|
|
||||||
oKey := kStore.GetKeyPublic(extraKID1)
|
|
||||||
assert.Same(t, &key4.PublicKey, oKey)
|
|
||||||
pKey := kStore.GetKeyPublic(extraKID2)
|
|
||||||
assert.Same(t, &key5.PublicKey, pKey)
|
|
||||||
aKey := kStore.GetKeyPublic("key1")
|
|
||||||
assert.NotNil(t, aKey)
|
|
||||||
bKey := kStore.GetKeyPublic("key2")
|
|
||||||
assert.NotNil(t, bKey)
|
|
||||||
cKey := kStore.GetKeyPublic("key3")
|
|
||||||
assert.NotNil(t, cKey)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestRemoveKey", func(t *testing.T) {
|
|
||||||
kStore.RemoveKey(extraKID1)
|
|
||||||
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
|
||||||
oKey1 := kStore.GetKey(extraKID1)
|
|
||||||
assert.Nil(t, oKey1)
|
|
||||||
oKey2 := kStore.GetKeyPublic(extraKID1)
|
|
||||||
assert.Nil(t, oKey2)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestClearKeys", func(t *testing.T) {
|
|
||||||
kStore.ClearKeys()
|
|
||||||
assert.Empty(t, kStore.ListKeys())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tempDir, cleaner := setupTestDirKeyStore(t, true)
|
|
||||||
defer cleaner(t)
|
|
||||||
|
|
||||||
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Len(t, kStore.ListKeys(), 3)
|
|
||||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
|
||||||
for _, k := range kIDsToFind {
|
|
||||||
assert.Contains(t, kStore.ListKeys(), k)
|
|
||||||
}
|
|
||||||
|
|
||||||
commonSubTestsKeyStore(t, kStore)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExportKeyStore(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tempDir, cleaner := setupTestDirKeyStore(t, true)
|
|
||||||
defer cleaner(t)
|
|
||||||
tempDir2, cleaner2 := setupTestDirKeyStore(t, false)
|
|
||||||
defer cleaner2(t)
|
|
||||||
|
|
||||||
kStore, err := NewMJwtKeyStoreFromDirectory(tempDir, kst_prvExt, kst_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
const prvExt2 = "v"
|
|
||||||
const pubExt2 = "b"
|
|
||||||
|
|
||||||
err = ExportKeyStore(kStore, tempDir2, prvExt2, pubExt2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
kStore2, err := NewMJwtKeyStoreFromDirectory(tempDir2, prvExt2, pubExt2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
kIDsToFind := kStore.ListKeys()
|
|
||||||
assert.Len(t, kStore2.ListKeys(), len(kIDsToFind))
|
|
||||||
for _, k := range kIDsToFind {
|
|
||||||
assert.Contains(t, kStore2.ListKeys(), k)
|
|
||||||
}
|
|
||||||
|
|
||||||
commonSubTestsKeyStore(t, kStore2)
|
|
||||||
}
|
|
233
keystore.go
Normal file
233
keystore.go
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
|
"github.com/1f349/rsa-helper/rsaprivate"
|
||||||
|
"github.com/1f349/rsa-helper/rsapublic"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/spf13/afero"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"io/fs"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrMissingPrivateKey = errors.New("missing private key")
|
||||||
|
var ErrMissingPublicKey = errors.New("missing public key")
|
||||||
|
var ErrMissingKeyPair = errors.New("missing key pair")
|
||||||
|
|
||||||
|
const PrivateStr = ".private"
|
||||||
|
const PublicStr = ".public"
|
||||||
|
|
||||||
|
const PemExt = ".pem"
|
||||||
|
const PrivatePemExt = PrivateStr + PemExt
|
||||||
|
const PublicPemExt = PublicStr + PemExt
|
||||||
|
|
||||||
|
type KeyStore struct {
|
||||||
|
mu *sync.RWMutex
|
||||||
|
store map[string]*keyPair
|
||||||
|
dir afero.Fs
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeyStore() *KeyStore {
|
||||||
|
return &KeyStore{
|
||||||
|
mu: new(sync.RWMutex),
|
||||||
|
store: make(map[string]*keyPair),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeyStoreWithDir(dir afero.Fs) *KeyStore {
|
||||||
|
keyStore := NewKeyStore()
|
||||||
|
keyStore.dir = dir
|
||||||
|
return keyStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeyStoreFromDir(dir afero.Fs) (*KeyStore, error) {
|
||||||
|
keyStore := NewKeyStoreWithDir(dir)
|
||||||
|
err := afero.Walk(dir, ".", func(path string, d fs.FileInfo, err error) error {
|
||||||
|
// maybe this is "name.private.pem"
|
||||||
|
name := filepath.Base(path)
|
||||||
|
ext := filepath.Ext(name)
|
||||||
|
if ext != PemExt {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name = strings.TrimSuffix(name, ext)
|
||||||
|
ext = filepath.Ext(name)
|
||||||
|
name = strings.TrimSuffix(name, ext)
|
||||||
|
switch ext {
|
||||||
|
case PrivateStr:
|
||||||
|
open, err := dir.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decode, err := rsaprivate.Decode(open)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
keyStore.LoadPrivateKey(name, decode)
|
||||||
|
return nil
|
||||||
|
case PublicStr:
|
||||||
|
open, err := dir.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
decode, err := rsapublic.Decode(open)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
keyStore.LoadPublicKey(name, decode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// still invalid
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return keyStore, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type keyPair struct {
|
||||||
|
private *rsa.PrivateKey
|
||||||
|
public *rsa.PublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) LoadPrivateKey(kid string, key *rsa.PrivateKey) {
|
||||||
|
k.mu.Lock()
|
||||||
|
if k.store[kid] == nil {
|
||||||
|
k.store[kid] = &keyPair{}
|
||||||
|
}
|
||||||
|
k.store[kid].private = key
|
||||||
|
k.store[kid].public = &key.PublicKey
|
||||||
|
k.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) LoadPublicKey(kid string, key *rsa.PublicKey) {
|
||||||
|
k.mu.Lock()
|
||||||
|
if k.store[kid] == nil {
|
||||||
|
k.store[kid] = &keyPair{}
|
||||||
|
}
|
||||||
|
k.store[kid].public = key
|
||||||
|
k.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) RemoveKey(kid string) {
|
||||||
|
k.mu.Lock()
|
||||||
|
delete(k.store, kid)
|
||||||
|
k.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) ListKeys() []string {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
|
keys := make([]string, 0, len(k.store))
|
||||||
|
for k, _ := range k.store {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) GetPrivateKey(kid string) (*rsa.PrivateKey, error) {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
|
if !k.internalHasPrivateKey(kid) {
|
||||||
|
return nil, ErrMissingPrivateKey
|
||||||
|
}
|
||||||
|
return k.store[kid].private, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) GetPublicKey(kid string) (*rsa.PublicKey, error) {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
|
if !k.internalHasPublicKey(kid) {
|
||||||
|
return nil, ErrMissingPublicKey
|
||||||
|
}
|
||||||
|
return k.store[kid].public, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) ClearKeys() {
|
||||||
|
k.mu.Lock()
|
||||||
|
clear(k.store)
|
||||||
|
k.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) HasPrivateKey(kid string) bool {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
|
return k.internalHasPrivateKey(kid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) internalHasPrivateKey(kid string) bool {
|
||||||
|
v := k.store[kid]
|
||||||
|
return v != nil && v.private != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) HasPublicKey(kid string) bool {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
|
return k.internalHasPublicKey(kid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) internalHasPublicKey(kid string) bool {
|
||||||
|
v := k.store[kid]
|
||||||
|
return v != nil && v.public != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
||||||
|
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
kid, ok := token.Header["kid"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrMissingPublicKey
|
||||||
|
}
|
||||||
|
return k.GetPublicKey(kid)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return withClaims, claims.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) SaveSingleKey(kid string) error {
|
||||||
|
if k.dir == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
k.mu.RLock()
|
||||||
|
pair := k.store[kid]
|
||||||
|
k.mu.RUnlock()
|
||||||
|
if pair == nil {
|
||||||
|
return ErrMissingKeyPair
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
if pair.private != nil {
|
||||||
|
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||||
|
}
|
||||||
|
if pair.public != nil {
|
||||||
|
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||||
|
}
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeyStore) SaveKeys() error {
|
||||||
|
k.mu.RLock()
|
||||||
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
|
workers := new(errgroup.Group)
|
||||||
|
workers.SetLimit(runtime.NumCPU())
|
||||||
|
for kid, pair := range k.store {
|
||||||
|
workers.Go(func() error {
|
||||||
|
var errs []error
|
||||||
|
if pair.private != nil {
|
||||||
|
errs = append(errs, afero.WriteFile(k.dir, kid+PrivatePemExt, rsaprivate.Encode(pair.private), 0600))
|
||||||
|
}
|
||||||
|
if pair.public != nil {
|
||||||
|
errs = append(errs, afero.WriteFile(k.dir, kid+PublicPemExt, rsapublic.Encode(pair.public), 0600))
|
||||||
|
}
|
||||||
|
return errors.Join(errs...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return workers.Wait()
|
||||||
|
}
|
175
keystore_test.go
Normal file
175
keystore_test.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
package mjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"github.com/1f349/rsa-helper/rsaprivate"
|
||||||
|
"github.com/1f349/rsa-helper/rsapublic"
|
||||||
|
"github.com/spf13/afero"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const kst_prvExt = "prv"
|
||||||
|
const kst_pubExt = "pub"
|
||||||
|
|
||||||
|
func setupTestDirKeyStore(t *testing.T, genKeys bool) afero.Fs {
|
||||||
|
tempDir := afero.NewMemMapFs()
|
||||||
|
|
||||||
|
if genKeys {
|
||||||
|
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = afero.WriteFile(tempDir, "key1.private.pem", rsaprivate.Encode(key1), 0600)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = afero.WriteFile(tempDir, "key2.private.pem", rsaprivate.Encode(key2), 0600)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = afero.WriteFile(tempDir, "key2.public.pem", rsapublic.Encode(&key2.PublicKey), 0600)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
key3, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = afero.WriteFile(tempDir, "key3.public.pem", rsapublic.Encode(&key3.PublicKey), 0600)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tempDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func commonSubTestsKeyStore(t *testing.T, kStore *KeyStore) {
|
||||||
|
key4, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
key5, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
const extraKID1 = "key4"
|
||||||
|
const extraKID2 = "key5"
|
||||||
|
|
||||||
|
t.Run("TestSetKey", func(t *testing.T) {
|
||||||
|
kStore.LoadPrivateKey(extraKID1, key4)
|
||||||
|
assert.Contains(t, kStore.ListKeys(), extraKID1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestSetKeyPublic", func(t *testing.T) {
|
||||||
|
kStore.LoadPublicKey(extraKID2, &key5.PublicKey)
|
||||||
|
assert.Contains(t, kStore.ListKeys(), extraKID2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestGetPrivateKey", func(t *testing.T) {
|
||||||
|
oKey, err := kStore.GetPrivateKey(extraKID1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Same(t, key4, oKey)
|
||||||
|
pKey, err := kStore.GetPrivateKey(extraKID2)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||||
|
assert.Nil(t, pKey)
|
||||||
|
aKey, err := kStore.GetPrivateKey("key1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, aKey)
|
||||||
|
bKey, err := kStore.GetPrivateKey("key2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, bKey)
|
||||||
|
cKey, err := kStore.GetPrivateKey("key3")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||||
|
assert.Nil(t, cKey)
|
||||||
|
wKey, err := kStore.GetPrivateKey("key1337")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||||
|
assert.Nil(t, wKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestGetPublicKey", func(t *testing.T) {
|
||||||
|
oKey, err := kStore.GetPublicKey(extraKID1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Same(t, &key4.PublicKey, oKey)
|
||||||
|
pKey, err := kStore.GetPublicKey(extraKID2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Same(t, &key5.PublicKey, pKey)
|
||||||
|
aKey, err := kStore.GetPublicKey("key1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, aKey)
|
||||||
|
bKey, err := kStore.GetPublicKey("key2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, bKey)
|
||||||
|
cKey, err := kStore.GetPublicKey("key3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, cKey)
|
||||||
|
wKey, err := kStore.GetPublicKey("key1337")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||||
|
assert.Nil(t, wKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestRemoveKey", func(t *testing.T) {
|
||||||
|
kStore.RemoveKey(extraKID1)
|
||||||
|
assert.NotContains(t, kStore.ListKeys(), extraKID1)
|
||||||
|
oKey1, err := kStore.GetPrivateKey(extraKID1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPrivateKey)
|
||||||
|
assert.Nil(t, oKey1)
|
||||||
|
oKey2, err := kStore.GetPublicKey(extraKID1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrMissingPublicKey)
|
||||||
|
assert.Nil(t, oKey2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TestClearKeys", func(t *testing.T) {
|
||||||
|
kStore.ClearKeys()
|
||||||
|
assert.Empty(t, kStore.ListKeys())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewMJwtKeyStoreFromDirectory(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tempDir := setupTestDirKeyStore(t, true)
|
||||||
|
|
||||||
|
kStore, err := NewKeyStoreFromDir(tempDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, kStore.ListKeys(), 3)
|
||||||
|
kIDsToFind := []string{"key1", "key2", "key3"}
|
||||||
|
for _, k := range kIDsToFind {
|
||||||
|
assert.Contains(t, kStore.ListKeys(), k)
|
||||||
|
}
|
||||||
|
assert.True(t, kStore.HasPrivateKey("key1"))
|
||||||
|
assert.True(t, kStore.HasPublicKey("key1")) // loading a private key also loads the public key
|
||||||
|
assert.True(t, kStore.HasPrivateKey("key2"))
|
||||||
|
assert.True(t, kStore.HasPublicKey("key2"))
|
||||||
|
assert.False(t, kStore.HasPrivateKey("key3"))
|
||||||
|
assert.True(t, kStore.HasPublicKey("key3"))
|
||||||
|
|
||||||
|
commonSubTestsKeyStore(t, kStore)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExportKeyStore(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tempDir := setupTestDirKeyStore(t, true)
|
||||||
|
tempDir2 := setupTestDirKeyStore(t, false)
|
||||||
|
|
||||||
|
kStore, err := NewKeyStoreFromDir(tempDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// internally swap directory
|
||||||
|
kStore.dir = tempDir2
|
||||||
|
|
||||||
|
err = kStore.SaveKeys()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
kStore2, err := NewKeyStoreFromDir(tempDir2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
kidList1 := kStore.ListKeys()
|
||||||
|
kidList2 := kStore2.ListKeys()
|
||||||
|
sort.Strings(kidList1)
|
||||||
|
sort.Strings(kidList2)
|
||||||
|
assert.Equal(t, kidList1, kidList2)
|
||||||
|
|
||||||
|
commonSubTestsKeyStore(t, kStore2)
|
||||||
|
}
|
143
mjwt_test.go
143
mjwt_test.go
@ -1,143 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"fmt"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var mt_ExtraKID = "tester"
|
|
||||||
|
|
||||||
type testClaims struct{ TestValue string }
|
|
||||||
|
|
||||||
func (t testClaims) Valid() error {
|
|
||||||
if t.TestValue != "hello" && t.TestValue != "world" {
|
|
||||||
return fmt.Errorf("TestValue should be hello")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testClaims) Type() string { return "testClaims" }
|
|
||||||
|
|
||||||
type testClaims2 struct{ TestValue2 string }
|
|
||||||
|
|
||||||
func (t testClaims2) Valid() error {
|
|
||||||
if t.TestValue2 != "world" {
|
|
||||||
return fmt.Errorf("TestValue2 should be world")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testClaims2) Type() string { return "testClaims2" }
|
|
||||||
|
|
||||||
func setupTestKeyStoreMJWT(t *testing.T) (ks KeyStore, a, b, c *rsa.PrivateKey) {
|
|
||||||
ks = NewMJwtKeyStore()
|
|
||||||
var err error
|
|
||||||
|
|
||||||
a, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
ks.SetKey("key1", a)
|
|
||||||
|
|
||||||
b, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
ks.SetKey("key2", b)
|
|
||||||
|
|
||||||
c, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
ks.SetKey("key3", c)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractClaims(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
kStore, key, _, _ := setupTestKeyStoreMJWT(t)
|
|
||||||
|
|
||||||
t.Run("TestNoKID", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
s := NewMJwtSigner("mjwt.test", key)
|
|
||||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
m := NewMJwtVerifier(&key.PublicKey)
|
|
||||||
_, _, err = ExtractClaims[testClaims](m, token)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestKID", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
|
||||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "hello"}, "key2")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
|
|
||||||
_, _, err = ExtractClaims[testClaims](m, token1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, _, err = ExtractClaims[testClaims](m, token2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractClaimsFail(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
kStore, key, key2, _ := setupTestKeyStoreMJWT(t)
|
|
||||||
|
|
||||||
t.Run("TestInvalidClaims", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
s := NewMJwtSigner("mjwt.test", key)
|
|
||||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
m := NewMJwtVerifier(&key.PublicKey)
|
|
||||||
_, _, err = ExtractClaims[testClaims2](m, token)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, ErrClaimTypeMismatch)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestDefaultKeyNoKID", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
|
||||||
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, "key1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
m := NewMJwtVerifier(&key.PublicKey)
|
|
||||||
_, _, err = ExtractClaims[testClaims](m, token)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestNoDefaultKey", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
|
||||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
m := NewMJwtVerifierWithKeyStore(nil, kStore)
|
|
||||||
_, _, err = ExtractClaims[testClaims](m, token)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TestKIDNonExist", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
kStore.SetKey(mt_ExtraKID, key2)
|
|
||||||
assert.Contains(t, kStore.ListKeys(), mt_ExtraKID)
|
|
||||||
|
|
||||||
s := NewMJwtSignerWithKeyStore("mjwt.test", key, kStore)
|
|
||||||
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "test"}, mt_ExtraKID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
kStore.RemoveKey(mt_ExtraKID)
|
|
||||||
assert.NotContains(t, kStore.ListKeys(), mt_ExtraKID)
|
|
||||||
|
|
||||||
m := NewMJwtVerifierWithKeyStore(&key.PublicKey, kStore)
|
|
||||||
_, _, err = ExtractClaims[testClaims](m, token)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, ErrNoPublicKeyFound)
|
|
||||||
})
|
|
||||||
}
|
|
204
signer.go
204
signer.go
@ -1,204 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"github.com/1f349/rsa-helper/rsaprivate"
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const readLimit = 10240 // 10 KiB
|
|
||||||
|
|
||||||
var ErrNoPrivateKeyFound = errors.New("no private key found")
|
|
||||||
|
|
||||||
// defaultMJwtSigner implements Signer and uses an rsa.PrivateKey and issuer name
|
|
||||||
// to generate MJWT tokens
|
|
||||||
type defaultMJwtSigner struct {
|
|
||||||
issuer string
|
|
||||||
key *rsa.PrivateKey
|
|
||||||
verify *defaultMJwtVerifier
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Signer = &defaultMJwtSigner{}
|
|
||||||
var _ Verifier = &defaultMJwtSigner{}
|
|
||||||
|
|
||||||
// NewMJwtSigner creates a new defaultMJwtSigner using the issuer name and rsa.PrivateKey
|
|
||||||
func NewMJwtSigner(issuer string, key *rsa.PrivateKey) Signer {
|
|
||||||
return NewMJwtSignerWithKeyStore(issuer, key, NewMJwtKeyStore())
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtSignerWithKeyStore creates a new defaultMJwtSigner using the issuer name, a rsa.PrivateKey
|
|
||||||
// for no kID and a KeyStore for kID based keys
|
|
||||||
func NewMJwtSignerWithKeyStore(issuer string, key *rsa.PrivateKey, kStore KeyStore) Signer {
|
|
||||||
var pKey *rsa.PublicKey = nil
|
|
||||||
if key != nil {
|
|
||||||
pKey = &key.PublicKey
|
|
||||||
}
|
|
||||||
return &defaultMJwtSigner{
|
|
||||||
issuer: issuer,
|
|
||||||
key: key,
|
|
||||||
verify: NewMJwtVerifierWithKeyStore(pKey, kStore).(*defaultMJwtVerifier),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtSignerFromFileOrCreate creates a new defaultMJwtSigner using the path
|
|
||||||
// of a rsa.PrivateKey file. If the file does not exist then the file is created
|
|
||||||
// and a new private key is generated.
|
|
||||||
func NewMJwtSignerFromFileOrCreate(issuer, file string, random io.Reader, bits int) (Signer, error) {
|
|
||||||
privateKey, err := readOrCreatePrivateKey(file, random, bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewMJwtSigner(issuer, privateKey), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtSignerFromFile creates a new defaultMJwtSigner using the path of a
|
|
||||||
// rsa.PrivateKey file.
|
|
||||||
func NewMJwtSignerFromFile(issuer, file string) (Signer, error) {
|
|
||||||
return NewMJwtSignerFromFileAndDirectory(issuer, file, "", "", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtSignerFromDirectory creates a new defaultMJwtSigner using the path of a directory to
|
|
||||||
// load the keys into a KeyStore; there is no default rsa.PrivateKey
|
|
||||||
func NewMJwtSignerFromDirectory(issuer, directory, prvExt, pubExt string) (Signer, error) {
|
|
||||||
return NewMJwtSignerFromFileAndDirectory(issuer, "", directory, prvExt, pubExt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtSignerFromFileAndDirectory creates a new defaultMJwtSigner using the path of a rsa.PrivateKey
|
|
||||||
// file as the non kID key and the path of a directory to load the keys into a KeyStore
|
|
||||||
func NewMJwtSignerFromFileAndDirectory(issuer, file, directory, prvExt, pubExt string) (Signer, error) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// read key
|
|
||||||
var prv *rsa.PrivateKey = nil
|
|
||||||
if file != "" {
|
|
||||||
prv, err = rsaprivate.Read(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// read KeyStore
|
|
||||||
var kStore KeyStore = nil
|
|
||||||
if directory != "" {
|
|
||||||
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewMJwtSignerWithKeyStore(issuer, prv, kStore), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Issuer returns the name of the issuer
|
|
||||||
func (d *defaultMJwtSigner) Issuer() string {
|
|
||||||
return d.issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateJwt generates and returns a JWT string using the sub, id, duration and claims; uses the default key
|
|
||||||
func (d *defaultMJwtSigner) GenerateJwt(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims) (string, error) {
|
|
||||||
return d.SignJwt(wrapClaims[Claims](d, sub, id, aud, dur, claims))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignJwt signs a jwt.Claims compatible struct, this is used internally by
|
|
||||||
// GenerateJwt but is available for signing custom structs; uses the default key
|
|
||||||
func (d *defaultMJwtSigner) SignJwt(wrapped jwt.Claims) (string, error) {
|
|
||||||
if d.key == nil {
|
|
||||||
return "", ErrNoPrivateKeyFound
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
|
||||||
return token.SignedString(d.key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateJwtWithKID generates and returns a JWT string using the sub, id, duration and claims; this gets signed with the specified kID
|
|
||||||
func (d *defaultMJwtSigner) GenerateJwtWithKID(sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims Claims, kID string) (string, error) {
|
|
||||||
return d.SignJwtWithKID(wrapClaims[Claims](d, sub, id, aud, dur, claims), kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignJwtWithKID signs a jwt.Claims compatible struct, this is used internally by
|
|
||||||
// GenerateJwt but is available for signing custom structs; this gets signed with the specified kID
|
|
||||||
func (d *defaultMJwtSigner) SignJwtWithKID(wrapped jwt.Claims, kID string) (string, error) {
|
|
||||||
pKey := d.verify.GetKeyStore().GetKey(kID)
|
|
||||||
if pKey == nil {
|
|
||||||
return "", ErrNoPrivateKeyFound
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS512, wrapped)
|
|
||||||
token.Header["kid"] = kID
|
|
||||||
return token.SignedString(pKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyJwt validates and parses MJWT tokens see defaultMJwtVerifier.VerifyJwt()
|
|
||||||
func (d *defaultMJwtSigner) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
|
||||||
return d.verify.VerifyJwt(token, claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtSigner) PrivateKey() *rsa.PrivateKey {
|
|
||||||
return d.key
|
|
||||||
}
|
|
||||||
func (d *defaultMJwtSigner) PublicKey() *rsa.PublicKey {
|
|
||||||
return d.verify.pub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtSigner) PublicKeyOf(kID string) *rsa.PublicKey {
|
|
||||||
return d.verify.kStore.GetKeyPublic(kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtSigner) GetKeyStore() KeyStore {
|
|
||||||
return d.verify.GetKeyStore()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtSigner) PrivateKeyOf(kID string) *rsa.PrivateKey {
|
|
||||||
return d.verify.kStore.GetKey(kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// readOrCreatePrivateKey returns the private key it the file already exists,
|
|
||||||
// generates a new private key and saves it to the file, or returns an error if
|
|
||||||
// reading or generating failed.
|
|
||||||
func readOrCreatePrivateKey(file string, random io.Reader, bits int) (*rsa.PrivateKey, error) {
|
|
||||||
// read the file or return nil
|
|
||||||
f, err := readOrEmptyFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if f == nil {
|
|
||||||
// generate a new key
|
|
||||||
key, err := rsa.GenerateKey(random, bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// save key to file
|
|
||||||
err = rsaprivate.Write(file, key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key, err
|
|
||||||
} else {
|
|
||||||
// return key
|
|
||||||
return rsaprivate.Decode(bytes.NewReader(f))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// readOrEmptyFile returns bytes and errors from os.OpenFile or (nil, nil) if the
|
|
||||||
// file does not exist.
|
|
||||||
func readOrEmptyFile(file string) ([]byte, error) {
|
|
||||||
fp, err := os.Open(file)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() { _ = fp.Close() }()
|
|
||||||
// add hard limit
|
|
||||||
limitReader := io.LimitReader(fp, readLimit)
|
|
||||||
raw, err := io.ReadAll(limitReader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return raw, nil
|
|
||||||
}
|
|
148
signer_test.go
148
signer_test.go
@ -1,148 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"github.com/1f349/rsa-helper/rsaprivate"
|
|
||||||
"github.com/1f349/rsa-helper/rsapublic"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
const st_prvExt = "prv"
|
|
||||||
const st_pubExt = "pub"
|
|
||||||
|
|
||||||
func setupTestDirSigner(t *testing.T) (string, *rsa.PrivateKey, func(t *testing.T)) {
|
|
||||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
var key3 *rsa.PrivateKey = nil
|
|
||||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+st_prvExt), key1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+st_prvExt), key2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+st_pubExt), &key2.PublicKey)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key3, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+st_pubExt), &key3.PublicKey)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
return tempDir, key3, func(t *testing.T) {
|
|
||||||
err := os.RemoveAll(tempDir)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtSigner(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
NewMJwtSigner("Test", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtSignerWithKeyStore(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
kStore := NewMJwtKeyStore()
|
|
||||||
kStore.SetKey("test", key)
|
|
||||||
assert.Contains(t, kStore.ListKeys(), "test")
|
|
||||||
NewMJwtSignerWithKeyStore("Test", nil, kStore)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtSignerFromFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
|
||||||
_, err = tempKey.Write(b)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, tempKey.Close())
|
|
||||||
signer, err := NewMJwtSignerFromFile("Test", tempKey.Name())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
|
||||||
_, err = NewMJwtSignerFromFile("Test", tempKey.Name())
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.True(t, os.IsNotExist(err))
|
|
||||||
assert.True(t, signer.(*defaultMJwtSigner).key.Equal(key))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtSignerFromFileOrCreate(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, tempKey.Close())
|
|
||||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
|
||||||
signer, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
signer2, err := NewMJwtSignerFromFileOrCreate("Test", tempKey.Name(), rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, signer.PrivateKey().Equal(signer2.PrivateKey()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadOrCreatePrivateKey(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tempKey, err := os.CreateTemp("", "key-test-*.pem")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
|
||||||
_, err = tempKey.Write(b)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, tempKey.Close())
|
|
||||||
key2, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, key.Equal(key2))
|
|
||||||
assert.NoError(t, os.Remove(tempKey.Name()))
|
|
||||||
key3, err := readOrCreatePrivateKey(tempKey.Name(), rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, key3.Validate())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtSignerFromDirectory(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tempDir, prvKey3, cleaner := setupTestDirSigner(t)
|
|
||||||
defer cleaner(t)
|
|
||||||
|
|
||||||
signer, err := NewMJwtSignerFromDirectory("Test", tempDir, st_prvExt, st_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Len(t, signer.GetKeyStore().ListKeys(), 3)
|
|
||||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
|
||||||
for _, k := range kIDsToFind {
|
|
||||||
assert.Contains(t, signer.GetKeyStore().ListKeys(), k)
|
|
||||||
}
|
|
||||||
assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3")))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtSignerFromFileAndDirectory(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tempDir, prvKey3, cleaner := setupTestDirSigner(t)
|
|
||||||
defer cleaner(t)
|
|
||||||
|
|
||||||
signer, err := NewMJwtSignerFromFileAndDirectory("Test", path.Join(tempDir, "key1.pem."+st_prvExt), tempDir, st_prvExt, st_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Len(t, signer.GetKeyStore().ListKeys(), 3)
|
|
||||||
kIDsToFind := []string{"key1", "key2", "key3"}
|
|
||||||
for _, k := range kIDsToFind {
|
|
||||||
assert.Contains(t, signer.GetKeyStore().ListKeys(), k)
|
|
||||||
}
|
|
||||||
assert.True(t, prvKey3.PublicKey.Equal(signer.GetKeyStore().GetKeyPublic("key3")))
|
|
||||||
assert.True(t, signer.PrivateKey().Equal(signer.GetKeyStore().GetKey("key1")))
|
|
||||||
}
|
|
108
verifier.go
108
verifier.go
@ -1,108 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"github.com/1f349/rsa-helper/rsapublic"
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrNoPublicKeyFound = errors.New("no public key found")
|
|
||||||
var ErrKIDInvalid = errors.New("kid invalid")
|
|
||||||
|
|
||||||
// defaultMJwtVerifier implements Verifier and uses a rsa.PublicKey to validate
|
|
||||||
// MJWT tokens
|
|
||||||
type defaultMJwtVerifier struct {
|
|
||||||
pub *rsa.PublicKey
|
|
||||||
kStore KeyStore
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Verifier = &defaultMJwtVerifier{}
|
|
||||||
|
|
||||||
// NewMJwtVerifier creates a new defaultMJwtVerifier using the rsa.PublicKey
|
|
||||||
func NewMJwtVerifier(key *rsa.PublicKey) Verifier {
|
|
||||||
return NewMJwtVerifierWithKeyStore(key, NewMJwtKeyStore())
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtVerifierWithKeyStore creates a new defaultMJwtVerifier using a rsa.PublicKey as the non kID key
|
|
||||||
// and a KeyStore for kID based keys
|
|
||||||
func NewMJwtVerifierWithKeyStore(defaultKey *rsa.PublicKey, kStore KeyStore) Verifier {
|
|
||||||
return &defaultMJwtVerifier{pub: defaultKey, kStore: kStore}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtVerifierFromFile creates a new defaultMJwtVerifier using the path of a
|
|
||||||
// rsa.PublicKey file
|
|
||||||
func NewMJwtVerifierFromFile(file string) (Verifier, error) {
|
|
||||||
return NewMJwtVerifierFromFileAndDirectory(file, "", "", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtVerifierFromDirectory creates a new defaultMJwtVerifier using the path of a directory to
|
|
||||||
// load the keys into a KeyStore; there is no default rsa.PublicKey
|
|
||||||
func NewMJwtVerifierFromDirectory(directory, prvExt, pubExt string) (Verifier, error) {
|
|
||||||
return NewMJwtVerifierFromFileAndDirectory("", directory, prvExt, pubExt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMJwtVerifierFromFileAndDirectory creates a new defaultMJwtVerifier using the path of a rsa.PublicKey
|
|
||||||
// file as the non kID key and the path of a directory to load the keys into a KeyStore
|
|
||||||
func NewMJwtVerifierFromFileAndDirectory(file, directory, prvExt, pubExt string) (Verifier, error) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// read key
|
|
||||||
var pub *rsa.PublicKey = nil
|
|
||||||
if file != "" {
|
|
||||||
pub, err = rsapublic.Read(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// read KeyStore
|
|
||||||
var kStore KeyStore = nil
|
|
||||||
if directory != "" {
|
|
||||||
kStore, err = NewMJwtKeyStoreFromDirectory(directory, prvExt, pubExt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewMJwtVerifierWithKeyStore(pub, kStore), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyJwt validates and parses MJWT tokens and returns the claims
|
|
||||||
func (d *defaultMJwtVerifier) VerifyJwt(token string, claims baseTypeClaim) (*jwt.Token, error) {
|
|
||||||
withClaims, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
kIDI, exs := token.Header["kid"]
|
|
||||||
if exs {
|
|
||||||
kID, ok := kIDI.(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, ErrKIDInvalid
|
|
||||||
}
|
|
||||||
key := d.kStore.GetKeyPublic(kID)
|
|
||||||
if key == nil {
|
|
||||||
return nil, ErrNoPublicKeyFound
|
|
||||||
} else {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if d.pub == nil {
|
|
||||||
return nil, ErrNoPublicKeyFound
|
|
||||||
}
|
|
||||||
return d.pub, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return withClaims, claims.Valid()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) PublicKey() *rsa.PublicKey {
|
|
||||||
return d.pub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) PublicKeyOf(kID string) *rsa.PublicKey {
|
|
||||||
return d.kStore.GetKeyPublic(kID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *defaultMJwtVerifier) GetKeyStore() KeyStore {
|
|
||||||
return d.kStore
|
|
||||||
}
|
|
111
verifier_test.go
111
verifier_test.go
@ -1,111 +0,0 @@
|
|||||||
package mjwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"github.com/1f349/rsa-helper/rsaprivate"
|
|
||||||
"github.com/1f349/rsa-helper/rsapublic"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const vt_prvExt = "prv"
|
|
||||||
const vt_pubExt = "pub"
|
|
||||||
|
|
||||||
func setupTestDirVerifier(t *testing.T, genKeys bool) (string, *rsa.PrivateKey, func(t *testing.T)) {
|
|
||||||
tempDir, err := os.MkdirTemp("", "this-is-a-test-dir")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
var key3 *rsa.PrivateKey = nil
|
|
||||||
|
|
||||||
if genKeys {
|
|
||||||
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsaprivate.Write(path.Join(tempDir, "key1.pem."+vt_prvExt), key1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsaprivate.Write(path.Join(tempDir, "key2.pem."+vt_prvExt), key2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsapublic.Write(path.Join(tempDir, "key2.pem."+vt_pubExt), &key2.PublicKey)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
key3, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rsapublic.Write(path.Join(tempDir, "key3.pem."+vt_pubExt), &key3.PublicKey)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tempDir, key3, func(t *testing.T) {
|
|
||||||
err := os.RemoveAll(tempDir)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtVerifierFromFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
s := NewMJwtSigner("mjwt.test", key)
|
|
||||||
token, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
b := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey)})
|
|
||||||
temp, err := os.CreateTemp("", "this-is-a-test-file.pem")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = temp.Write(b)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
file, err := NewMJwtVerifierFromFile(temp.Name())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, _, err = ExtractClaims[testClaims](file, token)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = os.Remove(temp.Name())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtVerifierFromDirectory(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
|
|
||||||
defer cleaner(t)
|
|
||||||
|
|
||||||
s, err := NewMJwtSignerFromDirectory("mjwt.test", tempDir, vt_prvExt, vt_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
s.GetKeyStore().SetKey("key3", prvKey3)
|
|
||||||
token, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
v, err := NewMJwtVerifierFromDirectory(tempDir, vt_prvExt, vt_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, _, err = ExtractClaims[testClaims](v, token)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewMJwtVerifierFromFileAndDirectory(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tempDir, prvKey3, cleaner := setupTestDirVerifier(t, true)
|
|
||||||
defer cleaner(t)
|
|
||||||
|
|
||||||
s, err := NewMJwtSignerFromFileAndDirectory("mjwt.test", path.Join(tempDir, "key2.pem."+vt_prvExt), tempDir, vt_prvExt, vt_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
s.GetKeyStore().SetKey("key3", prvKey3)
|
|
||||||
token1, err := s.GenerateJwt("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
token2, err := s.GenerateJwtWithKID("1", "test", nil, 10*time.Minute, testClaims{TestValue: "world"}, "key3")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
v, err := NewMJwtVerifierFromFileAndDirectory(path.Join(tempDir, "key2.pem."+vt_pubExt), tempDir, vt_prvExt, vt_pubExt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, _, err = ExtractClaims[testClaims](v, token1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, _, err = ExtractClaims[testClaims](v, token2)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user